Source code for hydra_zen.wrapper._implementations

# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
# pyright: strict, reportUnnecessaryTypeIgnoreComment = true, reportUnnecessaryIsInstance = false

import warnings
from collections import defaultdict
from contextvars import copy_context
from copy import deepcopy
from functools import partial, wraps
from inspect import Parameter, iscoroutinefunction, signature
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    DefaultDict,
    Dict,
    FrozenSet,
    Generator,
    Generic,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    overload,
)

import hydra
from hydra.conf import HydraConf
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing_extensions import (
    Final,
    Literal,
    ParamSpec,
    Protocol,
    Self,
    TypeAlias,
    TypedDict,
    TypeGuard,
)

from hydra_zen import instantiate
from hydra_zen._compatibility import HYDRA_VERSION, Version
from hydra_zen.errors import HydraZenValidationError
from hydra_zen.structured_configs._implementations import DefaultBuilds
from hydra_zen.structured_configs._type_guards import safe_getattr
from hydra_zen.typing._implementations import (
    DataClass_,
    GroupName,
    Node,
    NodeName,
    StoreEntry,
)

from ..structured_configs._type_guards import is_dataclass
from ..structured_configs._utils import safe_name

if TYPE_CHECKING:
    from hydra_zen import BuildsFn


# Names exported by `from ... import *` for this module.
__all__ = ["zen", "store", "Zen"]


# Generic placeholders used throughout this module:
# - R / P: return type and parameter specification of a wrapped callable
# - R2 / P2: same roles, used by the decorator-factory form of `zen`
# - F: unconstrained type; F2: any callable (used by identity-like wrappers)
R = TypeVar("R")
P = ParamSpec("P")
P2 = ParamSpec("P2")
R2 = TypeVar("R2")
F = TypeVar("F")
F2 = TypeVar("F2", bound=Callable[..., Any])


# Unique sentinel used as a default value to distinguish "argument not passed"
# from an explicit `None` (see `Zen.hydra_main`). Compare with `is`.
_UNSPECIFIED_: Any = object()


# Non-dataclass container types that `is_instantiable` treats as instantiable.
_SUPPORTED_INSTANTIATION_TYPES: Tuple[Any, ...] = (dict, DictConfig, list, ListConfig)

# Static type of the config objects accepted by `Zen.validate` / `Zen.__call__`
# (those methods additionally accept a yaml-string).
ConfigLike: TypeAlias = Union[
    DataClass_,
    Type[DataClass_],
    Dict[Any, Any],
    DictConfig,
]


def is_instantiable(
    cfg: Any,
) -> TypeGuard[ConfigLike]:
    """Report whether `cfg` should be passed through `instantiate`.

    Returns `True` for dataclass types/instances (as judged by `is_dataclass`)
    and for instances of the types in `_SUPPORTED_INSTANTIATION_TYPES`
    (dict, DictConfig, list, ListConfig); `False` otherwise.
    """
    if is_dataclass(cfg):
        return True
    return isinstance(cfg, _SUPPORTED_INSTANTIATION_TYPES)


# Parameter kinds that `Zen.validate` and `Zen.__call__` never look up by name
# in the input config: positional-only parameters and the `*args` / `**kwargs`
# catch-alls.
SKIPPED_PARAM_KINDS = frozenset(
    (Parameter.POSITIONAL_ONLY, Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL)
)


# Type of the `pre_call` argument: nothing, a single one-argument callable, or
# an iterable of such callables.
PreCall = Optional[Union[Callable[[Any], Any], Iterable[Callable[[Any], Any]]]]


def _flat_call(x: Iterable[Callable[P, Any]]) -> Callable[P, None]:
    def f(*args: P.args, **kwargs: P.kwargs) -> None:
        for fn in x:
            fn(*args, **kwargs)

    return f


def _identity(x: F2) -> F2:
    """Return `x` unchanged.

    Used by `Zen.__call__` as the fallback when no `instantiation_wrapper`
    was supplied.
    """
    return x


class Zen(Generic[P, R]):
    """Implements the wrapping logic that is exposed by `hydra_zen.zen`

    Attributes
    ----------
    func : Callable[Sig, R]
        The function that was wrapped.

    CFG_NAME : str
        The reserved parameter name that signals that the full input config
        should be passed through to the inner function. Can be overridden via
        subclassing. Defaults to 'zen_cfg'

    See Also
    --------
    zen : A decorator that returns a function that will auto-extract, resolve, and instantiate fields from an input config based on the decorated function's signature.
    """

    # Reserved parameter name used to pass the full config through to the
    # task function
    CFG_NAME: str = "zen_cfg"

    def __repr__(self) -> str:
        return f"zen[{(safe_name(self.func))}({', '.join(self.parameters)})](cfg, /)"

    def __init__(
        self,
        __func: Callable[P, R],
        *,
        exclude: Optional[Union[str, Iterable[str]]] = None,
        pre_call: PreCall = None,
        unpack_kwargs: bool = False,
        resolve_pre_call: bool = True,
        run_in_context: bool = False,
        instantiation_wrapper: Union[None, Callable[[F2], F2]] = None,
    ) -> None:
        """
        Parameters
        ----------
        func : Callable[Sig, R], positional-only
            The function being wrapped.

        unpack_kwargs: bool, optional (default=False)
            If `True` a `**kwargs` field in the wrapped function's signature will be
            populated by all of the input config entries that are not specified by the
            rest of the signature (and that are not specified by the `exclude`
            argument).

        pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]]
            One or more functions that will be called with the input config prior
            to the wrapped function. An iterable of pre-call functions are called
            from left (low-index) to right (high-index). An iterable is consumed
            once, at construction time, so generators are supported.

            This is useful, e.g., for seeding a RNG prior to the instantiation phase
            that is triggered when calling the wrapped function.

        resolve_pre_call : bool, (default=True)
            If `True`, the config passed to the zen-wrapped function has its
            interpolated fields resolved prior to being passed to any pre-call
            functions. Otherwise, the interpolation occurs after the pre-call functions
            are called.

        exclude : Optional[str | Iterable[str]]
            Specifies one or more parameter names in the function's signature
            that will not be extracted from input configs by the zen-wrapped function.

            A single string of comma-separated names can be specified.

        run_in_context : bool, optional (default=False)
            If `True`, the zen-wrapped function - and the `pre_call` function, if
            specified - is run in a copied :py:class:`contextvars.Context`; i.e.
            changes made to any :py:class:`contextvars.ContextVar` will be isolated to
            that call of the wrapped function.

            `run_in_context` is not supported for async functions.

        instantiation_wrapper : Optional[Callable[[F2], F2]], optional (default=None)
            If specified, a function that wraps the task function and all
            instantiation-targets before they are called.

            This can be used to introduce a layer of validation or logging
            to all instantiation calls in your application.

        Raises
        ------
        TypeError
            `run_in_context=True` was combined with an async function, or a
            boolean option was passed a non-`bool` value.
        HydraZenValidationError
            The wrapped function has no inspectable signature, a pre-call
            function cannot accept a single positional argument, or both this
            wrapper and a zen-wrapped pre-call specify `run_in_context=True`.
        """
        if run_in_context and iscoroutinefunction(__func):
            raise TypeError(f"`{run_in_context=} is not supported for async functions.")

        self.func: Callable[P, R] = __func
        try:
            # Must cast to dict so that `self` is pickle-compatible.
            self.parameters: Mapping[str, Parameter] = dict(
                signature(self.func).parameters
            )
        except (ValueError, TypeError):
            raise HydraZenValidationError(
                "hydra_zen.zen can only wrap callables that possess inspectable "
                "signatures."
            )

        if not isinstance(unpack_kwargs, bool):
            raise TypeError(f"`unpack_kwargs` must be type `bool` got {unpack_kwargs}")

        if not isinstance(resolve_pre_call, bool):  # pragma: no cover
            raise TypeError(
                f"`resolve_pre_call` must be type `bool` got {resolve_pre_call}"
            )

        if not isinstance(run_in_context, bool):  # pragma: no cover
            raise TypeError(
                f"`run_in_context` must be type `bool` got {run_in_context}"
            )

        self._instantiation_wrapper = instantiation_wrapper
        self._resolve = resolve_pre_call
        # Unpacking is only meaningful if the function actually has `**kwargs`
        self._unpack_kwargs: bool = unpack_kwargs and any(
            p.kind is p.VAR_KEYWORD for p in self.parameters.values()
        )
        self._run_in_context: bool = run_in_context
        self._exclude: Set[str]

        if exclude is None:
            self._exclude = set()
        elif isinstance(exclude, str):
            self._exclude = {k.strip() for k in exclude.split(",")}
        else:
            self._exclude = set(exclude)

        if self.CFG_NAME in self.parameters:
            # The reserved parameter is populated with the whole config at call
            # time; it must not be looked up as a config field.
            self._has_zen_cfg = True
            self.parameters = {
                name: param
                for name, param in self.parameters.items()
                if name != self.CFG_NAME
            }
        else:
            self._has_zen_cfg = False

        # Snapshot an iterable of pre-call functions exactly once. It is
        # traversed below for validation, again by `validate`, and on every
        # call of the wrapper; a one-shot iterable (e.g. a generator) would be
        # exhausted by the first traversal and the functions would never run.
        if isinstance(pre_call, Iterable):
            pre_call = tuple(pre_call)
            self._pre_call_iterable = pre_call
        else:
            self._pre_call_iterable = (pre_call,)

        # validate pre-call signatures
        for _f in self._pre_call_iterable:
            if _f is None:
                continue

            if run_in_context and isinstance(_f, Zen) and _f._run_in_context:
                raise HydraZenValidationError(
                    f"zen-wrapped pre_call function {_f!r} cannot specify "
                    f"`run_in_context=True` when the main wrapper specifies it as well."
                )

            _f_params = signature(_f).parameters  # type: ignore

            if (sum(p.default is p.empty for p in _f_params.values()) > 1) or len(
                _f_params
            ) == 0:
                raise HydraZenValidationError(
                    f"pre_call function {_f} must be able to accept a single "
                    "positional argument"
                )

        self.pre_call: Optional[Callable[[Any], Any]] = (
            pre_call if not isinstance(pre_call, Iterable) else _flat_call(pre_call)
        )

    def _normalize_cfg(
        self,
        cfg: Union[
            DataClass_,
            Type[DataClass_],
            Dict[Any, Any],
            List[Any],
            ListConfig,
            DictConfig,
            str,
        ],
    ) -> DictConfig:
        """Convert a supported config input into a `DictConfig`.

        Dataclasses are converted via `OmegaConf.structured`; dicts and
        yaml-strings via `OmegaConf.create`; OmegaConf containers are passed
        through.

        Raises
        ------
        HydraZenValidationError
            `cfg` is of an unsupported type, or it does not produce a
            `DictConfig` (e.g. a list, or a yaml-string describing a list).
        """
        if is_dataclass(cfg):
            # ensures that default factories and interpolated fields
            # are resolved
            cfg = OmegaConf.structured(cfg)

        elif not OmegaConf.is_config(cfg):
            if not isinstance(cfg, (dict, str)):
                raise HydraZenValidationError(
                    f"`cfg` must be a dataclass, dict/DictConfig, or "
                    f"dict-style yaml-string. Got {cfg}"
                )
            cfg = OmegaConf.create(cfg)

        if not isinstance(cfg, DictConfig):
            raise HydraZenValidationError(
                f"`cfg` must be a dataclass, dict/DictConfig, or "
                f"dict-style yaml-string. Got {cfg}"
            )
        return cfg

    def validate(self, __cfg: Union[ConfigLike, str]) -> None:
        """Validates the input config based on the decorated function without
        calling said function.

        Zen-wrapped pre-call functions are validated against the same config
        first.

        Parameters
        ----------
        cfg : dict | DataClass | Type[DataClass] | str (positional only)
            A config object or yaml-string whose attributes will be checked
            according to the signature of `func`.

        Raises
        ------
        HydraZenValidationError
            `cfg` is not a valid input to the zen-wrapped function.
        """
        for _f in self._pre_call_iterable:
            if isinstance(_f, Zen):
                _f.validate(__cfg)

        cfg = self._normalize_cfg(__cfg)

        num_pos_only = sum(
            p.kind is p.POSITIONAL_ONLY for p in self.parameters.values()
        )

        _args_: List[Any] = getattr(cfg, "_args_", [])

        if not isinstance(_args_, Sequence):
            raise HydraZenValidationError(
                f"`cfg._args_` must be a sequence type (e.g. a list), got {_args_}"
            )
        if num_pos_only and len(_args_) != num_pos_only:
            raise HydraZenValidationError(
                f"{self.func} has {num_pos_only} positional-only arguments, but "
                f"`cfg` specifies {len(_args_)} positional "
                f"arguments via `_args_`."
            )

        # A parameter is "missing" only if it is looked up by name, is not
        # excluded, has no default, and is absent from the config.
        missing_params: List[str] = []
        for name, param in self.parameters.items():
            if name in self._exclude:
                continue

            if param.kind in SKIPPED_PARAM_KINDS:
                continue

            if not hasattr(cfg, name) and param.default is param.empty:
                missing_params.append(name)

        if missing_params:
            raise HydraZenValidationError(
                f"`cfg` is missing the following fields: {', '.join(missing_params)}"
            )

    def instantiate(self, __c: Any) -> Any:
        """Instantiates each config that is extracted by `zen` before calling the
        wrapped function. Overwrite this to change `ZenWrapper`'s instantiation
        behavior."""
        __c = instantiate(__c, _target_wrapper_=self._instantiation_wrapper)
        if isinstance(__c, (ListConfig, DictConfig)):
            # hand plain Python objects, not OmegaConf containers, to `func`
            return OmegaConf.to_object(__c)
        else:
            return __c

    # TODO: add "extract" option that enables returning dict of fields
    def __call__(self, __cfg: Union[ConfigLike, str]) -> R:
        """
        Extracts values from the input config based on the decorated function's
        signature, resolves & instantiates them, and calls the function with them.

        Parameters
        ----------
        cfg : dict | DataClass | Type[DataClass] | str (positional only)
            A config object or yaml-string whose attributes will be extracted by-name
            according to the signature of `func` and passed to `func`.

            Attributes of types that can be instantiated by Hydra will be instantiated
            prior to being passed to `func`.

        Returns
        -------
        func_out : R
            The result of `func(<args extracted from cfg>)`
        """
        cfg = self._normalize_cfg(__cfg)

        if self._resolve:
            # resolves all interpolated values in-place
            OmegaConf.resolve(cfg)

        # A single copied context is shared by the pre-call and the main call
        context = copy_context() if self._run_in_context else None

        if self.pre_call is not None:
            pre_call = (
                self.pre_call
                if context is None
                else partial(context.run, self.pre_call)
            )
            pre_call(cfg)

        # positional arguments are supplied via the config's `_args_` entry
        args_ = list(getattr(cfg, "_args_", []))

        cfg_kwargs = {
            name: (
                safe_getattr(cfg, name, param.default)
                if param.default is not param.empty
                else safe_getattr(cfg, name)
            )
            for name, param in self.parameters.items()
            if param.kind not in SKIPPED_PARAM_KINDS and name not in self._exclude
        }

        extra_kwargs = {self.CFG_NAME: cfg} if self._has_zen_cfg else {}
        if self._unpack_kwargs:
            # remaining string-keyed config entries feed the function's **kwargs
            names = (
                name
                for name in cfg
                if name not in cfg_kwargs
                and name not in self._exclude
                and isinstance(name, str)
            )
            cfg_kwargs.update({name: cfg[name] for name in names})

        wrapper = self._instantiation_wrapper or _identity
        func: Callable[P, R] = (
            wrapper(self.func)  # type: ignore
            if context is None
            else partial(context.run, wrapper(self.func))  # type: ignore
        )
        return func(
            *(self.instantiate(x) if is_instantiable(x) else x for x in args_),
            **{
                name: self.instantiate(val) if is_instantiable(val) else val
                for name, val in cfg_kwargs.items()
            },
            **extra_kwargs,
        )  # type: ignore

    def hydra_main(
        self,
        config_path: Optional[str] = _UNSPECIFIED_,
        config_name: Optional[str] = None,
        version_base: Optional[str] = _UNSPECIFIED_,
    ) -> Callable[[Any], Any]:
        """
        Generates a Hydra-CLI for the wrapped function. Equivalent to
        `hydra.main(zen(func), [...])()`

        Parameters
        ----------
        config_path : Optional[str]
            The config path, an absolute path to a directory or a directory relative to
            the declaring python file. If `config_path` is not specified no directory is
            added to the config search path.

            Specifying `config_path` via `Zen.hydra_main` is only supported for Hydra
            1.3.0+.

        config_name : Optional[str]
            The name of the config (usually the file name without the .yaml extension)

        version_base : Optional[str]
            There are three classes of values that the version_base parameter supports,
            given new and existing users greater control of the default behaviors to
            use.

            - If the version_base parameter is not specified, Hydra 1.x will use
              defaults compatible with version 1.1. Also in this case, a warning is
              issued to indicate an explicit version_base is preferred.
            - If the version_base parameter is None, then the defaults are chosen for
              the current minor Hydra version. For example for Hydra 1.2, then would
              imply config_path=None and hydra.job.chdir=False.
            - If the version_base parameter is an explicit version string like "1.1",
              then the defaults appropriate to that version are used.

        Returns
        -------
        hydra_main : Callable[[Any], Any]
            Equivalent to `hydra.main(zen(func), [...])()`
        """
        kw = dict(config_name=config_name)

        # For relative config paths, Hydra looks in the directory relative to the file
        # in which the task function is defined. Unfortunately, it is only able to
        # follow wrappers starting in Hydra 1.3.0. Thus `Zen.hydra_main` cannot
        # handle string config_path entries until Hydra 1.3.0
        if (config_path is _UNSPECIFIED_ and HYDRA_VERSION < Version(1, 2, 0)) or (
            (
                isinstance(config_path, str)
                or (config_path is _UNSPECIFIED_ and version_base == "1.1")
            )
            and HYDRA_VERSION < Version(1, 3, 0)
        ):  # pragma: no cover
            warnings.warn(
                "Specifying config_path via hydra_zen.zen(...).hydra_main "
                "is only supported for Hydra 1.3.0+"
            )

        if Version(1, 3, 0) <= HYDRA_VERSION and isinstance(config_path, str):
            # Here we create an on-the-fly wrapper so that Hydra can trace
            # back through the wrapper to the original task function

            # We could give `Zen` as `__wrapped__` attr, but this messes with
            # things like `inspect.signature`.
            #
            # A downside of this is that `wrapper` is not pickle-able.
            @wraps(self.func)
            def wrapper(cfg: Any):
                return self(cfg)

            target = wrapper
        else:
            target = self

        # Only forward arguments the caller actually supplied, so that Hydra's
        # own defaults apply otherwise.
        if config_path is not _UNSPECIFIED_:
            kw["config_path"] = config_path

        if version_base is not _UNSPECIFIED_:  # pragma: no cover
            kw["version_base"] = version_base

        return hydra.main(**kw)(target)()
@overload
def zen(
    __func: Callable[P, R],
    *,
    unpack_kwargs: bool = ...,
    pre_call: PreCall = ...,
    ZenWrapper: Type[Zen[Any, Any]] = ...,
    resolve_pre_call: bool = ...,
    run_in_context: bool = ...,
    exclude: Optional[Union[str, Iterable[str]]] = ...,
    instantiation_wrapper: Optional[Callable[[F2], F2]] = ...,
) -> Zen[P, R]: ...


@overload
def zen(
    __func: Literal[None] = None,
    *,
    unpack_kwargs: bool = ...,
    pre_call: PreCall = ...,
    resolve_pre_call: bool = ...,
    ZenWrapper: Type[Zen[Any, Any]] = ...,
    run_in_context: bool = ...,
    exclude: Optional[Union[str, Iterable[str]]] = ...,
    instantiation_wrapper: Optional[Callable[[F2], F2]] = ...,
) -> Callable[[Callable[P2, R2]], Zen[P2, R2]]: ...


def zen(
    __func: Optional[Callable[P, R]] = None,
    *,
    unpack_kwargs: bool = False,
    pre_call: PreCall = None,
    exclude: Optional[Union[str, Iterable[str]]] = None,
    resolve_pre_call: bool = True,
    run_in_context: bool = False,
    ZenWrapper: Type[Zen[Any, Any]] = Zen,
    instantiation_wrapper: Optional[Callable[[F2], F2]] = None,
) -> Union[Callable[[Callable[P2, R2]], Zen[P2, R2]], Zen[P, R]]:
    r"""zen(func, /, pre_call, ZenWrapper)

    A wrapper that returns a function that will auto-extract, resolve, and
    instantiate fields from an input config based on the wrapped function's signature.

    .. code-block:: pycon

       >>> fn = lambda x, y, z : x+y+z
       >>> wrapped_fn = zen(fn)

       >>> cfg = dict(x=1, y=builds(int, 4), z="${y}", unused=100)
       >>> wrapped_fn(cfg)  # x=1, y=4, z=4
       9

    The main purpose of `zen` is to enable a user to write/use Hydra-agnostic functions
    as the task functions for their Hydra app. See "Notes" for more details.

    Parameters
    ----------
    func : Callable[Sig, R], positional-only
        The function being wrapped. If omitted, `zen` returns a decorator that
        accepts the function.

    unpack_kwargs: bool, optional (default=False)
        If `True` a `**kwargs` field in the wrapped function's signature will be
        populated by all of the input config entries that are not specified by the
        rest of the signature (and that are not specified by the `exclude` argument).

    pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]]
        One or more functions that will be called with the input config prior
        to the wrapped function. An iterable of pre-call functions are called
        from left (low-index) to right (high-index).

        This is useful, e.g., for seeding a RNG prior to the instantiation phase
        that is triggered when calling the wrapped function.

    resolve_pre_call : bool, (default=True)
        If `True`, the config passed to the zen-wrapped function has its interpolated
        fields resolved prior to being passed to any pre-call functions. Otherwise, the
        interpolation occurs after the pre-call functions are called.

    exclude : Optional[str | Iterable[str]]
        Specifies one or more parameter names in the function's signature
        that will not be extracted from input configs by the zen-wrapped function.

        A single string of comma-separated names can be specified.

    run_in_context : bool, optional (default=False)
        If `True`, the zen-wrapped function - and the `pre_call` function, if
        specified - is run in a copied :py:class:`contextvars.Context`; i.e.
        changes made to any :py:class:`contextvars.ContextVar` will be isolated to
        that call of the wrapped function.

        `run_in_context` is not supported for async functions.

    ZenWrapper : Type[hydra_zen.wrapper.Zen], optional (default=Zen)
        If specified, a subclass of `Zen` that customizes the behavior of the wrapper.

    instantiation_wrapper : Optional[Callable[[F2], F2]], optional (default=None)
        If specified, a function that wraps the task function and all
        instantiation-targets before they are called.

        This can be used to introduce a layer of validation or logging
        to all instantiation calls in your application.

    Returns
    -------
    wrapped : Zen[Sig, R]
        A callable with signature `(conf: ConfigLike, /) -> R`

        The wrapped function is an instance of `hydra_zen.wrapper.Zen` and accepts
        a single Hydra config (a dataclass, dictionary, or omegaconf container). The
        parameters of the wrapped function's signature determine the fields that are
        extracted from the config; only those fields that are accessed will be
        resolved and instantiated.

    See Also
    --------
    hydra_zen.wrapper.Zen : Implements the wrapping logic that is exposed by `hydra_zen.zen`.

    Notes
    -----
    The following pseudo code conveys the core functionality of `zen`:

    .. code-block:: python

       from hydra_zen import instantiate as inst

       def zen(func):
           sig = get_signature(func)

           def wrapped(cfg):
               cfg = resolve_interpolated_fields(cfg)
               kwargs = {p: inst(getattr(cfg, p)) for p in sig}
               return func(**kwargs)
           return wrapped

    The presence of a parameter named "zen_cfg" in the wrapped function's signature
    will cause `zen` to pass the full, resolved config to that field. This specific
    parameter name can be overridden via `Zen.CFG_NAME`.

    Specifying `config_path` via `Zen.hydra_main` is only supported for Hydra 1.3.0+.

    Examples
    --------
    **Basic Usage**

    >>> from hydra_zen import zen, make_config, builds
    >>> def f(x, y): return x + y
    >>> zen_f = zen(f)

    The wrapped function – `zen_f` – accepts a single argument: a Hydra-compatible
    config that has the attributes "x" and "y":

    >>> zen_f
    zen[f(x, y)](cfg, /)

    "Configs" – dataclasses, dictionaries, and omegaconf containers – are acceptable
    inputs to zen-wrapped functions. Interpolated fields will be resolved and
    sub-configs will be instantiated. Excess fields in the config are unused.

    >>> zen_f(make_config(x=1, y=2, z=999))  # z is not used
    3
    >>> zen_f(dict(x=2, y="${x}"))  # y will resolve to 2
    4
    >>> zen_f(dict(x=2, y=builds(int, 10)))  # y will instantiate to 10
    12

    The wrapped function can be accessed directly

    >>> zen_f.func
    <function __main__.f(x, y)>
    >>> zen_f.func(-1, 1)
    0

    `zen` is compatible with partial'd functions.

    >>> from functools import partial
    >>> pf = partial(lambda x, y: x + y, x=10)
    >>> zpf = zen(pf)
    >>> zpf(dict(y=1))
    11
    >>> zpf(dict(x='${y}', y=1))
    2

    One can specify `exclude` to prevent particular variables from being extracted from
    a config:

    >>> def g(x=1, y=2): return (x, y)
    >>> cfg = {"x": -10, "y": -20}
    >>> zen(g)(cfg)  # extracts x & y from config to call g
    (-10, -20)
    >>> zen(g, exclude="x")(cfg)  # extracts y from config to call g(x=1, ...)
    (1, -20)
    >>> zen(g, exclude="x,y")(cfg)  # defers to g's defaults
    (1, 2)

    Populating a `**kwargs` field via `unpack_kwargs=True`:

    >>> def h(a, **kw):
    ...     return a, kw

    >>> cfg = dict(a=1, b=22)
    >>> zen(h, unpack_kwargs=False)(cfg)
    (1, {})
    >>> zen(h, unpack_kwargs=True)(cfg)
    (1, {'b': 22})

    **Passing Through The Full Input Config**

    Some task functions require complete access to the full config to gain access to
    sub-configs. One can specify the field named `zen_cfg` in their task function's
    signature to signal `zen` that it should pass the full config to that parameter.

    >>> def zf(x: int, zen_cfg):
    ...     return x, zen_cfg
    >>> zen(zf)(dict(x=1, y="${x}", foo="bar"))
    (1, {'x': 1, 'y': 1, 'foo': 'bar'})

    **Including a pre-call function**

    Given that a zen-wrapped function will automatically extract and instantiate config
    fields upon being called, it can be necessary to include a pre-call step that
    occurs prior to any instantiation. `zen` can be passed one or more pre-call
    functions that will be called with the input config as a precursor to calling the
    decorated function.

    Consider the following scenario where the instantiating the input config involves
    drawing a random value, which we want to be made deterministic with a configurable
    seed. We will use a pre-call function to seed the RNG prior to the instantiation.

    >>> import random
    >>> from hydra_zen import builds, zen
    >>>
    >>> def func(rand_val: int): return rand_val
    >>>
    >>> cfg = dict(
    ...         seed=0,
    ...         rand_val=builds(random.randint, 0, 10),
    ... )

    >>> wrapped = zen(func, pre_call=lambda cfg: random.seed(cfg.seed))

    >>> def f1(rand_val: int):
    ...     return rand_val
    >>> zf1 = zen(pre_call=lambda cfg: random.seed(cfg.seed))(f1)
    >>> [zf1(cfg) for _ in range(10)]
    [6, 6, 6, 6, 6, 6, 6, 6, 6, 6]

    **Using `zen` instead of `@hydra.main`**

    The object returned by zen provides a convenience method – `Zen.hydra_main` – to
    generate a CLI for a zen-wrapped task function:

    .. code-block:: python

       # example.py
       from hydra_zen import zen, store

       @store(name="my_app")
       def task(x: int, y: int):
           print(x + y)

       if __name__ == "__main__":
           store.add_to_hydra_store()
           zen(task).hydra_main(config_name="my_app", config_path=None, version_base="1.2")

    .. code-block:: console

       $ python example.py x=1 y=2
       3

    **Validating input configs**

    An input config can be validated against a zen-wrapped function – without calling
    said function – via the `.validate` method.

    >>> def f2(x: int): ...
    >>> zen_f = zen(f2)
    >>> zen_f.validate({"x": 1})  # OK
    >>> zen_f.validate({"y": 1})  # Missing x
    HydraZenValidationError: `cfg` is missing the following fields: x

    Validation propagates through zen-wrapped pre-call functions:

    >>> zen_f2 = zen(f2, pre_call=zen(lambda seed: None))
    >>> zen_f2.validate({"x": 1, "seed": 10})  # OK
    >>> zen_f2.validate({"x": 1})  # Missing seed as required by pre-call
    HydraZenValidationError: `cfg` is missing the following fields: seed
    """

    def wrap(f: Callable[P2, R2]) -> Zen[P2, R2]:
        # Single construction site shared by the direct-call form `zen(f, ...)`
        # and the decorator-factory form `zen(...)(f)`.
        return cast(
            Zen[P2, R2],
            ZenWrapper(
                f,
                pre_call=pre_call,
                exclude=exclude,
                unpack_kwargs=unpack_kwargs,
                resolve_pre_call=resolve_pre_call,
                run_in_context=run_in_context,
                instantiation_wrapper=instantiation_wrapper,
            ),
        )

    if __func is not None:
        return wrap(__func)
    return wrap
def default_to_config(
    target: Union[
        Callable[..., Any],
        DataClass_,
        List[Any],
        Dict[Any, Any],
        ListConfig,
        DictConfig,
    ],
    CustomBuildsFn: Type["BuildsFn[Any]"] = DefaultBuilds,
    **kw: Any,
) -> Union[DataClass_, Type[DataClass_], ListConfig, DictConfig]:
    """Produce a config describing `target`.

    `hydra_zen.builds` or `hydra_zen.just` is applied selectively, so that the
    common kinds of inputs given to `hydra_zen.ZenStore` are all handled. What
    happens depends on the type of `target`:

    - OmegaConf containers and dataclass *instances* are returned unchanged
    - A dataclass type is processed as `builds(target, **kw, populate_full_signature=True, builds_bases=(target,))`,
      except that subclasses of `HydraConf` - and, when no keyword arguments
      are given, dataclass types whose object path starts with ``types.`` (as
      produced by `make_config`) - are returned unchanged
    - Lists and dictionaries are processed by `hydra_zen.just`
    - All other inputs are processed as `builds(target, **kw, populate_full_signature=True)`

    Parameters
    ----------
    target : Callable[..., Any] | DataClass | Type[DataClass] | list | dict

    CustomBuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds)
        Provides the config-creation functions (`builds`, `just`) used by this
        function.

    **kw : Any
        Keyword arguments to be passed to `builds`.

    Returns
    -------
    target_config : DataClass | Type[DataClass] | list | dict

    Raises
    ------
    ValueError
        `target` is a dataclass instance and keyword arguments were specified.

    Examples
    --------
    Lists and dictionaries

    >>> from hydra_zen.wrapper import default_to_config
    >>> default_to_config([1, {"z": 2+2j}])
    [1, {'z': ConfigComplex(real=2.0, imag=2.0, _target_='builtins.complex')}]

    Dataclass types

    >>> from dataclasses import dataclass
    >>>
    >>> @dataclass
    ... class A:
    ...     x: int
    ...     y: int
    >>> Builds_A = default_to_config(A, y=22)
    >>> Builds_A(x=1)
    Builds_A(x=1, y=22, _target_='__main__.A')
    >>> issubclass(Builds_A, A)
    True

    A function

    >>> from hydra_zen import to_yaml
    >>> def func(x: int, y: int): ...
    >>> print(to_yaml(default_to_config(func)))
    _target_: __main__.func
    x: ???
    'y': ???
    """
    kw = kw.copy()

    if not is_dataclass(target):
        if isinstance(target, (dict, list)):
            # TODO: convert to OmegaConf containers?
            return CustomBuildsFn.just(target)

        if isinstance(target, (DictConfig, ListConfig)):
            return target

        # any other object is treated as a callable to be built
        fn = cast(Callable[..., Any], target)
        kw.setdefault("populate_full_signature", True)
        return cast(Type[DataClass_], CustomBuildsFn.builds(fn, **kw))

    if not isinstance(target, type):
        # dataclass instance: stored as-is, so there is nothing to apply kw to
        if kw:
            raise ValueError(
                "store(<dataclass-instance>, [...]) does not support specifying "
                "keyword arguments"
            )
        return target

    if issubclass(target, HydraConf):
        # don't auto-config HydraConf
        return target

    if not kw and CustomBuildsFn._get_obj_path(target).startswith("types."):  # type: ignore
        # handles dataclasses returned by make_config()
        return target

    kw.setdefault("populate_full_signature", True)
    kw.setdefault("builds_bases", (target,))
    return CustomBuildsFn.builds(target, **kw)
class _HasName(Protocol):
    __name__: str


# TODO: Should we automatically snake-case?
def get_name(target: _HasName) -> str:
    """Return `target.__name__`, for use as the default store-entry name.

    Raises
    ------
    TypeError
        If `target` has no `__name__` attribute, or if it is not a string.
    """
    inferred = getattr(target, "__name__", None)
    if isinstance(inferred, str):
        return inferred

    raise TypeError(
        f"Cannot infer config store entry name for {target}. It does not have a "
        f"`__name__` attribute. Please manually specify `store({target}, "
        f"name=<some name>, [...])`"
    )


class _StoreCallSig(TypedDict):
    """Arguments for ZenStore.__call__

    This typed dict enables us to easily update/merge the default arguments for a
    specific ZenStore instance, in support of self-partialing behavior."""

    name: Union[NodeName, Callable[[Any], NodeName]]
    group: Union[GroupName, Callable[[Any], GroupName]]
    package: Optional[Union[str, Callable[[Any], str]]]
    provider: Optional[str]
    __kw: Dict[str, Any]  # kwargs passed to to_config
    to_config: Callable[[Any], Any]


# Baseline `ZenStore.__call__` arguments; each store starts from a copy of these.
# TODO: make frozen dict
defaults: Final = _StoreCallSig(
    name=get_name,
    group=None,
    package=None,
    provider=None,
    to_config=default_to_config,
    __kw={},
)

# Keyword names that `ZenStore.__call__` treats as its own options rather than
# forwarding them to `to_config`.
_DEFAULT_KEYS: Final[FrozenSet[str]] = frozenset(
    _StoreCallSig.__required_keys__ - {"__kw"}
)


class _Deferred:
    """A pending `to_config(target, **kw)` call; invoking the instance runs it."""

    __slots__ = ("to_config", "target", "kw")

    def __init__(
        self, to_config: Callable[[F], Node], target: F, kw: Dict[str, Any]
    ) -> None:
        self.to_config = to_config
        self.target = target
        self.kw = kw

    def __call__(self) -> Any:
        return self.to_config(self.target, **self.kw)


def _resolve_node(entry: StoreEntry, copy: bool) -> StoreEntry:
    """Ensure that `entry["node"]` is not deferred, then return the entry.

    A deferred node is evaluated and written back into `entry` in place, so the
    work is done at most once per entry. An entry whose node is not deferred is
    passed through. When `copy` is true, a shallow copy of the (resolved) entry is
    returned instead of the entry itself."""
    node = entry["node"]
    if isinstance(node, _Deferred):
        entry["node"] = node()
    return entry.copy() if copy else entry
class ZenStore:
    """An abstraction over Hydra's store, for creating multiple, isolated config stores.

    Whereas Hydra exposes a single global config store that provides no warnings when
    store entries are overwritten, `ZenStore` instances are isolated, do not populate
    the global store unless instructed to, and they protect users from unwittingly
    overwriting store entries.

    Notes
    -----
    `hydra_zen.store` is available as a pre-instantiated globally-available store,
    which is initialized as:

    .. code-block:: python

       store = ZenStore(
           name="zen_store",
           deferred_to_config=True,
           deferred_hydra_store=True,
       )

    Internally, each `ZenStore` instance holds a mapping of::

      tuple[group, name] -> {node: Dataclass | omegaconf.Container,
                             name: str,
                             group: str,
                             package: Optional[str],
                             provider: Optional[str]}

    **Auto-config capabilities**

    `ZenStore` is also designed to consolidate the config-creation and storage
    process; it can be used to automatically apply a config-creation function (e.g.,
    `~hydra_zen.builds`) to a target in order to auto-generate a config for the
    target, which is then stored.

    .. tab-set::

       .. tab-item:: Via auto-config

          .. code-block:: python

             from hydra_zen import store

             def func(x, y):
                 ...

             store(func, x=2, y=3)

       .. tab-item:: Via manually-specified config

          .. code-block:: python

             from hydra_zen import builds, store

             def func(x, y):
                 ...

             store(builds(func, x=2, y=3, populate_full_signature=True), name="func")

    It can also be used to decorate config-targets and dataclasses, enabling "inline"
    config creation and storage patterns.

    These auto config-creation capabilities are designed to be deferred until a config
    is actually accessed by users or added to Hydra's global config store. This
    enables store-decorator patterns to be used within library code without slowing
    down import times.

    **Self-partialing patterns**

    A `ZenStore` instance can be called repeatedly - without a config target - with
    different options to incrementally change the store's default configurations. E.g.
    the following are effectively equivalent

    .. tab-set::

       .. tab-item:: Self-partialing pattern

          .. code-block:: python

             from hydra_zen import store

             book_store = store(group="books")
             romance_store = book_store(provider="genre: romance")
             fantasy_store = book_store(provider="genre: fantasy")

             romance_store({"title": "heartfelt"})
             romance_store({"title": "lustfully longingness"})

             fantasy_store({"title": "elvish cookbook"})
             fantasy_store({"title": "dwarves can't jump"})

       .. tab-item:: Manual pattern

          .. code-block:: python

             from hydra_zen import store

             store(
                 {"title": "heartfelt"},
                 group="book",
                 provider="genre: romance",
             )
             store(
                 {"title": "lustfully longingness"},
                 group="book",
                 provider="genre: romance",
             )

             store(
                 {"title": "elvish cookbook"},
                 group="book",
                 provider="genre: fantasy",
             )
             store(
                 {"title": "dwarves can't jump"},
                 group="book",
                 provider="genre: fantasy",
             )

    **Configuring Hydra itself**

    Special support is provided for overriding Hydra's configuration; the name and
    group of the store entry is inferred to be 'config' and 'hydra', respectively,
    when an instance/subclass of `HydraConf` is being stored. E.g., specifying

    .. code-block:: python

       from hydra.conf import HydraConf, JobConf
       from hydra_zen import store

       store(HydraConf(job=JobConf(chdir=True)))

    is equivalent to writing the following manually

    .. code-block:: python

       store(HydraConf(job=JobConf(chdir=True)), name="config", group="hydra", provider="hydra_zen")

    Additionally, overwriting the store entry for `HydraConf` will not raise an error
    even if `ZenStore(overwrite_ok=False)` is specified.

    Examples
    --------
    (Some helpful boilerplate code for these examples)

    >>> from hydra_zen import to_yaml, store, ZenStore
    >>> def pyaml(x):
    ...     # for pretty printing configs
    ...     print(to_yaml(x))

    **Basic usage**

    Let's add a config to hydra-zen's pre-instantiated `ZenStore` instance. Each store
    entry must have an associated name. Optionally, a group, package, and/or provider
    may be specified for the entry as well.

    >>> config1 = {'name': 'Roger', 'age': 24}
    >>> config2 = {'name': 'Rita', 'age': 27}
    >>> _ = store(config1, name="roger", group="profiles")
    >>> _ = store(config2, name="rita", group="profiles")
    >>> store
    zen_store
    {'profiles': ['roger', 'rita']}

    A store's entries are keyed by their `(group, name)` pairs (the default group is
    `None`).

    >>> store["profiles", "roger"]  # (group, name) -> config node
    {'name': 'Roger', 'age': 24}

    By default, the stored config(s) will be "enqueued" for addition to Hydra's config
    store. The method `.add_to_hydra_store()` must be called to add the enqueued
    configs to Hydra's central store.

    >>> store.has_enqueued()
    True
    >>> store.add_to_hydra_store()  # adds all enqueued entries to Hydra's global store
    >>> store.has_enqueued()
    False

    By default, attempting to overwrite an entry will result in an error.

    >>> store({}, name="rita", group="profiles")  # same name and group as above
    ValueError: (name=rita group=profiles): Store entry already exists. Use a store initialized with `ZenStore(overwrite_ok=True)` to overwrite config store entries.

    We can create a distinct store that has an independent internal repository of
    configs.

    >>> new_store = ZenStore("new_store")
    >>> _ = new_store([1, 2, 3], name="backbone")
    >>> store
    zen_store
    {'profiles': ['roger', 'rita']}
    >>> new_store
    new_store
    {None: ['backbone']}

    .. _store-autoconf:

    **Auto-config capabilities**

    The input to a store is processed by the store's `to_config` function prior to
    creating the stored config node. This defaults to
    `hydra_zen.wrapper.default_to_config`, which applies `hydra_zen.builds` or
    `hydra_zen.just` to inputs based on their types.

    For instance, consider the following function:

    >>> def sum_it(a: int, b: int): return a + b

    We can pass `sum_it` directly to our store to leverage auto-config and auto-naming
    capabilities. Here, `builds(sum_it, a=1, b=2)` will be called under the hood by
    `store2` to create the config for `sum_it`.

    >>> store2 = ZenStore()
    >>> _ = store2(sum_it, a=1, b=2)  # entry name defaults to `sum_it.__name__`
    >>> config = store2[None, "sum_it"]
    >>> pyaml(config)
    _target_: __main__.sum_it
    a: 1
    b: 2

    Refer to `hydra_zen.wrapper.default_to_config` for more details about the default
    auto-config behaviors of `ZenStore`.

    **Support for decorator patterns**

    `ZenStore.__call__` is a pass-through and can be used as a decorator. Let's add two
    store entries for `func` by decorating it.

    >>> store = ZenStore()

    >>> @store(a=1, b=22, name="func1")
    ... @store(a=-10, name="func2")
    ... def func(a: int, b: int):
    ...     return a - b

    Each application of `@store` utilizes the store's auto-config capability to create
    and store a config inline. I.e. the above snippet is equivalent to

    >>> from hydra_zen import builds
    >>>
    >>> store(builds(func, a=1, b=22), name="func1")
    >>> store(builds(func, a=-10,
    ...              populate_full_signature=True
    ...              ),
    ...       name="func2",
    ... )

    >>> func(10, 3)  # the decorated function is left unchanged
    7
    >>> pyaml(store[None, "func1"])
    _target_: __main__.func
    a: 1
    b: 22
    >>> pyaml(store[None, "func2"])
    _target_: __main__.func
    b: ???
    a: -10

    Note that, by default, the application of `to_config` via the store **is deferred
    until that entry is actually accessed**. This offsets the runtime cost of
    constructing configs for the decorated function so that it need not be paid until
    the config is actually accessed by the store.

    .. _self-partial:

    **Customizable store defaults via 'self-partialing' patterns**

    The default values for a store's `__call__` parameters – `group`, `to_config`,
    etc. – can easily be customized. Simply call the store with those new values and
    without specifying an object to be stored. This will return a "mirrored" store
    instance – with the same internal state as the original store – with updated
    defaults.

    For example, let's create a store where we want to store multiple configs under a
    `'math'` group and under a `'functools'` group, respectively.

    >>> import math
    >>> import functools

    >>> new_store = ZenStore()
    >>> math_store = new_store(group="math")  # overwrites group default
    >>> tool_store = new_store(group="functools")  # overwrites group default

    `math_store` and `tool_store` both share the same internal state as `new_store`,
    but have overwritten default values for the `group`.

    >>> math_store(math.floor)  # equivalent to: `new_store(math.floor, group="math")`
    >>> math_store(math.ceil)
    >>> tool_store(functools.lru_cache)
    >>> tool_store(functools.wraps)

    See that `new_store` has entries under these corresponding groups:

    >>> new_store
    custom_store
    {'math': ['floor', 'ceil'], 'functools': ['lru_cache', 'wraps']}

    These "self-partialing" patterns can be chained indefinitely and can be used to
    set partial defaults for the `to_config` function.

    >>> profile_store = new_store(group="profile")
    >>> schemaless = profile_store(schema="<none>")

    >>> from dataclasses import dataclass
    >>>
    >>> @profile_store(name="admin", has_root=True)
    ... @schemaless(name="test_admin", has_root=True)
    ... @schemaless(name="test_user", has_root=False)
    ... @dataclass
    ... class Profile:
    ...     username: str
    ...     schema: str
    ...     has_root: bool

    >>> pyaml(new_store["profile", "admin"])
    username: ???
    schema: ???
    has_root: true
    _target_: __main__.Profile

    >>> pyaml(new_store["profile", "test_admin"])
    username: ???
    schema: <none>
    has_root: true
    _target_: __main__.Profile

    **Manipulating and updating a store**

    A store can be copied, updated, and merged. Its entries can have their groups
    remapped, and individual entries can be deleted. See the docs for the
    corresponding methods for details and examples.
    """

    __slots__ = (
        "name",
        "_internal_repo",
        "_defaults",
        "_queue",
        "_deferred_to_config",
        "_deferred_store",
        "_overwrite_ok",
        "_warn_node_kwarg",
    )

    def __init__(
        self,
        name: Optional[str] = None,
        *,
        deferred_to_config: bool = True,
        deferred_hydra_store: bool = True,
        overwrite_ok: bool = False,
        warn_node_kwarg: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        name : Optional[str]
            The name for this store. Defaults to ``"custom_store"``.

        deferred_to_config : bool, default=True
            If `True` (default), this store will not apply `to_config` to the
            target until that specific entry is accessed by the store.

        deferred_hydra_store : bool, default=True
            If `True` (default), this store will not add entries to Hydra's global
            config store until `store.add_to_hydra_store` is called explicitly.

        overwrite_ok : bool, default=False
            If `False` (default), attempting to overwrite entries in this store and
            trying to use this store to overwrite entries in Hydra's global store
            will raise a `ValueError`.

        warn_node_kwarg: bool, default=True
            If `True` specifying a `node` kwarg in `ZenStore.__call__` will emit a
            warning.

            This helps to protect users from mistakenly self-partializing a store
            with `store(node=Config)` instead of actually storing the node with
            `store(Config)`.

        Raises
        ------
        TypeError
            If `deferred_to_config`, `overwrite_ok`, or `deferred_hydra_store` is
            not a bool.
        """
        if not isinstance(deferred_to_config, bool):
            raise TypeError(
                f"deferred_to_config must be a bool, got {deferred_to_config}"
            )

        if not isinstance(overwrite_ok, bool):
            raise TypeError(f"overwrite_ok must be a bool, got {overwrite_ok}")

        if not isinstance(deferred_hydra_store, bool):
            raise TypeError(
                f"deferred_hydra_store must be a bool, got {deferred_hydra_store}"
            )

        self.name: str = "custom_store" if name is None else name

        # The following attributes are mirrored across store instances that are
        # created via the 'self-partialing' process
        self._internal_repo: Dict[Tuple[GroupName, NodeName], StoreEntry] = {}

        # Internal repo entries that have yet to be added to Hydra's config store
        self._queue: Set[Tuple[GroupName, NodeName]] = set()

        self._deferred_to_config = deferred_to_config
        self._deferred_store = deferred_hydra_store
        self._overwrite_ok = overwrite_ok

        # Contains the current default arguments for `self.__call__`
        self._defaults: _StoreCallSig = defaults.copy()

        self._warn_node_kwarg = warn_node_kwarg

    def __repr__(self) -> str:
        # TODO: nicer repr?
        # Renders the store's name followed by a `{group: [entry names]}` summary.
        groups_contents: DefaultDict[Optional[str], List[str]] = defaultdict(list)
        for grp, name in self._internal_repo:
            groups_contents[grp].append(name)
        return f"{self.name}\n{repr(dict(groups_contents))}"

    def __eq__(self, __o: object) -> bool:
        """Returns `True` if two stores share identical internal repos and queues.

        The comparison is by identity, not by content: only stores that mirror one
        another (i.e., were produced by self-partialing) compare equal.

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> store1 = ZenStore()
        >>> store2 = ZenStore()
        >>> store1_a = store1(group='a')
        >>> _ = store1_a(dict(x=1), name="foo")

        >>> store1 == store1_a
        True
        >>> store1 == store2
        False
        """
        if not isinstance(__o, ZenStore):
            return False
        return __o._internal_repo is self._internal_repo and __o._queue is self._queue

    # TODO: support *to_config_pos_args
    @overload
    def __call__(
        self,
        __target: F,
        *,
        name: Union[NodeName, Callable[[F], NodeName]] = ...,
        group: Union[GroupName, Callable[[F], GroupName]] = ...,
        package: Optional[Union[str, Callable[[F], str]]] = ...,
        provider: Optional[str] = ...,
        to_config: Callable[[F], Node] = default_to_config,
        **to_config_kw: Any,
    ) -> F: ...

    @overload
    def __call__(
        self: Self,
        __target: Literal[None] = None,
        *,
        name: Union[NodeName, Callable[[Any], NodeName]] = ...,
        group: Union[GroupName, Callable[[Any], GroupName]] = ...,
        package: Optional[Union[str, Callable[[Any], str]]] = ...,
        provider: Optional[str] = ...,
        to_config: Callable[[Any], Node] = ...,
        **to_config_kw: Any,
    ) -> Self: ...

    def __call__(self: Self, __target: Optional[F] = None, **kw: Any) -> Union[F, Self]:
        """__call__(target : Optional[T] = None, /, name: NodeName | Callable[[Any], NodeName] = ..., group: GroupName | Callable[[T], GroupName] = None, package: Optional[str | Callable[[T], str]] = None, provider: Optional[str] = None, to_config: Callable[[T], Node] = ..., **to_config_kw: Any) -> T | ZenStore

        Store a config or :ref:`customize the default values <self-partial>` of the
        store.

        Parameters
        ----------
        obj : Optional[T]
            The object to be stored. This is a **positional-only** argument.

            If `obj` is not specified, then the provided arguments are used to create
            a mirrored store instance with updated default arguments.

        name : NodeName | Callable[[T], NodeName]
            The entry's name, or a callable that will be called as
            `(obj) -> entry-name`. The default is `lambda obj: obj.__name__`.

            Store entries are keyed off of `(group, name)`.

        group : Optional[GroupName | Callable[[T], GroupName]]
            The entry's group's name, or a callable that will be called as
            `(obj) -> entry-group`. The default is `None`.

            Subgroups can be specified using / within the group name.

            Store entries are keyed off of `(group, name)`.

        to_config : Callable[[T], Node] = default_to_config
            Called on `obj` to produce the entry's "node" (the config). Refer to
            `hydra_zen.wrapper.default_to_config` for the default behavior. Specify
            `lambda x: x` to have `obj` be stored directly as the entry's node.

            By default the call to `to_config` is deferred until the entry is
            actually accessed by the store.

        package : Optional[str | Callable[[Any], str]]
            The entry's package. Default is `None`.

        provider : Optional[str]
            An optional provider name for the entry.

        **to_config_kw : Any
            Additional arguments that will be passed to `to_config`.

        Returns
        -------
        T | ZenStore
            If `obj` was specified, it is returned unchanged.

            Otherwise a new instance of `ZenStore` is returned, which mirrors the
            internal state of this store and has updated default arguments.

        Raises
        ------
        TypeError
            If the resolved name is not a string, or if the resolved group or
            package is neither a string nor `None`.
        ValueError
            If the entry already exists and this store does not permit overwrites.
        """
        if __target is None:
            if self._warn_node_kwarg and "node" in kw:
                warnings.warn(
                    "hydra-zen's store API does not use the `node` keyword. To store a "
                    "config, specify it as a positional argument: `store(<config>)`."
                    "\n\nIf the use of `node` was intentional, you can suppress this "
                    "warning by using a store that is initialized via `ZenStore"
                    "(warn_node_kwarg=False).",
                    # Attribute the warning to the caller's line, not to this module.
                    stacklevel=2,
                )
            _s = type(self)(
                self.name,
                deferred_to_config=self._deferred_to_config,
                deferred_hydra_store=self._deferred_store,
                overwrite_ok=self._overwrite_ok,
                warn_node_kwarg=self._warn_node_kwarg,
            )
            _s._defaults = self._defaults.copy()

            # Important: mirror internal state *by reference* to ensure `_s` and
            # `self` remain in sync
            _s._internal_repo = self._internal_repo
            _s._queue = self._queue

            new_defaults: _StoreCallSig = {k: kw[k] for k in _DEFAULT_KEYS if k in kw}  # type: ignore

            # Keywords that are not store options accumulate as `to_config` defaults.
            new_defaults["__kw"] = {
                **_s._defaults["__kw"],
                **{k: kw[k] for k in set(kw) - _DEFAULT_KEYS},
            }
            _s._defaults.update(new_defaults)
            return _s
        else:
            to_config = kw.get("to_config", self._defaults["to_config"])
            name = kw.get("name", self._defaults["name"])
            group = kw.get("group", self._defaults["group"])
            package = kw.get("package", self._defaults["package"])
            provider = kw.get("provider", self._defaults["provider"])

            if (
                isinstance(__target, HydraConf)
                or isinstance(__target, type)
                and issubclass(__target, HydraConf)
            ):
                # User is re-configuring Hydra's config; we provide "smart" defaults
                # for the entry's name, group, and package
                if "name" not in kw and "group" not in kw:  # pragma: no branch
                    # only apply when neither name nor group are specified
                    name = "config"
                    group = "hydra"

                if "provider" not in kw:  # pragma: no branch
                    provider = "hydra_zen"

            _name: NodeName = name(__target) if callable(name) else name  # type: ignore
            if not isinstance(_name, str):
                raise TypeError(f"`name` must be a string, got {_name}")
            del name

            _group: GroupName = group(__target) if callable(group) else group  # type: ignore
            if _group is not None and not isinstance(_group, str):
                raise TypeError(f"`group` must be a string or None, got {_group}")
            del group

            _pkg = package(__target) if callable(package) else package
            if _pkg is not None and not isinstance(_pkg, str):
                raise TypeError(f"`package` must be a string or None, got {_pkg}")
            del package

            # Per-call keywords take precedence over the store's partialed defaults.
            merged_kw = {
                **self._defaults["__kw"],
                **{k: kw[k] for k in set(kw) - _DEFAULT_KEYS},
            }

            if self._deferred_to_config:
                node = _Deferred(to_config, __target, merged_kw)
            else:
                node = to_config(__target, **merged_kw)

            entry = StoreEntry(
                name=_name,
                group=_group,
                package=_pkg,
                provider=provider,
                node=node,
            )
            self._set_entry(entry, overwrite=self._overwrite_ok)
            return cast(Union[F, Self], __target)

    def copy(self: Self, store_name: Optional[str] = None) -> Self:
        """Returns a copy of the store with the same overridden defaults.

        The copy is a deep copy: entries added to the copy do not appear in the
        original store, and vice versa.

        Parameters
        ----------
        store_name : str | None, optional (default=None)
            The name of the new store. Defaults to `<this store's name>_copy`.

        Returns
        -------
        ZenStore

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> s1 = ZenStore("s1")(group="G")
        >>> s1({}, name="a")
        >>> s2 = s1.copy()
        >>> s2({}, name="b")
        >>> s1
        s1
        {'G': ['a']}
        >>> s2
        s1_copy
        {'G': ['a', 'b']}
        """
        cp = deepcopy(self)
        cp.name = store_name if store_name is not None else self.name + "_copy"
        return cp

    def copy_with_mapped_groups(
        self: Self,
        old_group_to_new_group: Union[
            Mapping[GroupName, GroupName], Callable[[GroupName], GroupName]
        ],
        *,
        store_name: Optional[str] = None,
        overwrite_ok: Optional[bool] = None,
    ) -> Self:
        """Create a copy of a store whose entries' groups have been updated according
        to the provided mapping.

        Parameters
        ----------
        old_group_to_new_group : Mapping[GroupName, GroupName] | Callable[[GroupName], GroupName]
            A mapping or callable that transforms an old group name to a new one.
            Groups in the store that are not included in the mapping are unaffected.

            A `GroupName` is `str | None`.

        store_name : Optional[str]
            If specified, the name of the new store.

        overwrite_ok : Optional[bool]
            If specified, determines if the mapping can overwrite existing store
            entries. Otherwise, defers to `ZenStore(overwrite_ok)`.

        Returns
        -------
        new_store
            A copy of `self` with remapped groups.

        Raises
        ------
        ValueError
            If a remapped entry collides with another entry of the new store and
            overwriting is not permitted.

        Examples
        --------
        >>> from hydra_zen import ZenStore

        Creating an initial store

        >>> s1 = ZenStore("s1")
        >>> s1({}, group=None, name="a")
        >>> s1({}, group="A/1", name="b")
        >>> s1({}, group="A/2", name="c")
        >>> s1
        s1
        {None: ['a'], 'A/1': ['b'], 'A/2': ['c']}

        Replacing group "A/1" with "B", via a mapping

        >>> s2 = s1.copy_with_mapped_groups({"A/1": "B"}, store_name="s2")
        >>> s2
        s2
        {None: ['a'], 'A/2': ['c'], 'B': ['b']}

        Placing all entries under group "A/" within a new inner group "p", via a
        function

        >>> s3 = s1.copy_with_mapped_groups(
        ...     lambda g: g + "/p" if g and g.startswith("A/") else g, store_name="s3"
        ... )
        >>> s3
        s3
        {None: ['a'], 'A/1/p': ['b'], 'A/2/p': ['c']}
        """
        overwrite = overwrite_ok if overwrite_ok is not None else self._overwrite_ok

        map_fn: Callable[[GroupName], GroupName] = (
            (lambda x: old_group_to_new_group.get(x, x))
            if isinstance(old_group_to_new_group, Mapping)
            else old_group_to_new_group
        )

        new_store = self.copy(store_name)

        # Phase 1: work out every move up-front, against the un-mutated repo.
        moves: List[Tuple[GroupName, NodeName, GroupName, StoreEntry]] = []
        for (group, name), entry in tuple(new_store._internal_repo.items()):
            new_group = map_fn(group)
            if new_group != group:
                moves.append((group, name, new_group, entry))

        # Phase 2: vacate all old keys *before* inserting any new key. Interleaving
        # the two would let chained/swapped mappings (A -> B together with B -> C)
        # delete a just-moved entry or report a conflict that does not exist in the
        # final layout.
        for group, name, _, _ in moves:
            del new_store[group, name]

        for _, _, new_group, entry in moves:
            entry["group"] = new_group
            new_store._set_entry(entry, overwrite=overwrite)
        return new_store

    @property
    def groups(self) -> Sequence[GroupName]:
        """Returns a sorted list of the groups registered with this store.

        The `None` group, if present, is listed first."""
        set_: Set[GroupName] = set(group for group, _ in self._internal_repo)
        if None in set_:
            # `None` cannot be ordered against strings; pull it out before sorting.
            set_.remove(None)
            no_none = cast(Set[str], set_)
            return [None] + sorted(no_none)
        else:
            no_none = cast(Set[str], set_)
            return sorted(no_none)

    def enqueue_all(self) -> None:
        """Add all of the store's entries to the queue to be added to hydra's store.

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> store = ZenStore(deferred_hydra_store=True)
        >>> store({"a": 1}, name="a")
        >>> store.has_enqueued()
        True

        >>> store.add_to_hydra_store()
        >>> store.has_enqueued()
        False

        >>> store.enqueue_all()
        >>> store.has_enqueued()
        True
        """
        self._queue.update(self._internal_repo.keys())

    def has_enqueued(self) -> bool:
        """`True` if this store has entries that have not yet been added to Hydra's
        config store.

        Returns
        -------
        bool

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> store = ZenStore(deferred_hydra_store=True)
        >>> store.has_enqueued()
        False

        >>> store({"a": 1}, name="a")
        >>> store.has_enqueued()
        True

        >>> store.add_to_hydra_store()
        >>> store.has_enqueued()
        False
        """
        return bool(self._queue)

    def __bool__(self) -> bool:
        """`True` if entries have been added to this store, regardless of whether or
        not they have been added to Hydra's config store"""
        return bool(self._internal_repo)

    def __len__(self) -> int:
        """The number of entries held by this store."""
        return len(self._internal_repo)

    def update(self, __other: "ZenStore") -> None:
        """Updates the store inplace with redundant entries being overwritten.

        Can also be applied via the `|=` in-place operator. Updating a store with a
        store that mirrors it is a no-op.

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> def f(): ...
        >>> def g(): ...

        >>> s1 = ZenStore("s1")
        >>> s2 = ZenStore("s2")
        >>> s1(f)  # store f in s1
        >>> s2(g)  # store g in s2
        >>> s1.update(s2)
        >>> s1  # s1 now has entries for both f and g
        s1
        {None: ['f', 'g']}

        Alternatively, the `|=` operator can be used to update a store inplace.

        >>> s3 = ZenStore("s3")
        >>> s3 |= s2
        >>> s3
        s3
        {None: ['g']}
        """
        if __other == self:
            # Both stores share one repo and queue; there is nothing to merge.
            return
        # Deep-copy so that later mutations of either store cannot leak across.
        self._internal_repo.update(deepcopy(__other._internal_repo))
        self._queue.update(__other._queue)

        if not self._deferred_store:
            self.add_to_hydra_store()

    def merge(
        self: Self, __other: "ZenStore", store_name: Optional[str] = None
    ) -> Self:
        """Create a new store by merging two stores.

        The new store's default settings will reflect those of `self` in
        `self.merge(other)`.

        This can also be applied via the `|` operator.

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> def f(): ...
        >>> def g(): ...

        >>> s1 = ZenStore("s1")
        >>> s2 = ZenStore("s2")
        >>> s1(f)  # store f in s1
        >>> s2(g)  # store g in s2
        >>> s3 = s1.merge(s2)
        >>> s3
        s1_copy
        {None: ['f', 'g']}

        Alternatively, the `|` operator can be used to merge stores.

        >>> s4 = s1 | s2
        >>> s4
        s1_copy
        {None: ['f', 'g']}
        """
        cp = self.copy(store_name)
        cp.update(__other)
        return cp

    def __or__(self: Self, other: "ZenStore") -> Self:
        return self.merge(other)

    def __ior__(self: Self, other: "ZenStore") -> Self:
        self.update(other)
        return self

    @overload
    def __getitem__(self, key: Tuple[GroupName, NodeName]) -> Node: ...

    @overload
    def __getitem__(self, key: GroupName) -> Dict[Tuple[GroupName, NodeName], Node]: ...

    def __getitem__(self, key: Union[GroupName, Tuple[GroupName, NodeName]]) -> Node:
        """Access an entry's config node by specifying `(group, name)`. Or, access a
        mapping of `(group, name) -> node` for all nodes in a specified group,
        including nodes within subgroups.

        Accessing a node resolves it, if its creation had been deferred.

        See Also
        --------
        ZenStore.get_entry

        Examples
        --------
        >>> from hydra_zen import store
        >>> store(dict(x=1), name="a", group="fruit")
        >>> store(dict(x=2), name="b", group="fruit/apple")
        >>> store(dict(x=3), name="c", group="fruit/apple")
        >>> store(dict(x=4), name="d", group="fruit/orange")
        >>> store(dict(x=5), name="e", group="veggie")

        Accessing an individual entry's config node.

        >>> store["fruit/apple", "b"]
        {'x': 2}

        Accessing all config nodes under the "fruit/apple" group

        >>> store["fruit/apple"]
        {('fruit/apple', 'b'): {'x': 2},
         ('fruit/apple', 'c'): {'x': 3}}

        Accessing all config nodes under the "fruit" group

        >>> store["fruit"]
        {('fruit', 'a'): {'x': 1},
         ('fruit/apple', 'b'): {'x': 2},
         ('fruit/apple', 'c'): {'x': 3},
         ('fruit/orange', 'd'): {'x': 4}}
        """
        # store[group] ->
        #  {(group, name): node1, (group, name2): node2, (group/subgroup, name3): node3}
        #
        # store[group, name] -> node
        if isinstance(key, str) or key is None:
            key_not_none = key is not None
            # The trailing "/" keeps group "fruit" from matching group "fruits".
            key_w_ender = key + "/" if key is not None else "<ZEN_NEVER>"
            return {
                (group, name): _resolve_node(entry, copy=False)["node"]
                for (group, name), entry in self._internal_repo.items()
                if group == key
                or (
                    key_not_none
                    and group is not None
                    and group.startswith(key_w_ender)
                )
            }
        return _resolve_node(self._internal_repo[key], copy=False)["node"]

    def __delitem__(self, key: Tuple[GroupName, NodeName]) -> None:
        """Remove the `(group, name)` entry from the store and from its queue.

        Raises `KeyError` if the store holds no such entry."""
        del self._internal_repo[key]
        self._queue.discard(key)

    def delete_entry(self, group: GroupName, name: NodeName) -> None:
        """Remove the entry keyed by `(group, name)`; equivalent to
        `del store[group, name]`.

        Raises `KeyError` if the store holds no such entry."""
        del self[group, name]

    def get_entry(self, group: GroupName, name: NodeName) -> StoreEntry:
        """Access a store entry, which is a mapping that specifies the entry's name,
        group, package, provider, and node.

        Parameters
        ----------
        group : str | None
        name : str

        Returns
        -------
        dict
            - name: NodeName
            - group: GroupName
            - package: Optional[str]
            - provider: Optional[str]
            - node: ConfigType

        Notes
        -----
        Mutating the returned mapping will not affect the store's internal entry.
        Mutating a node in the returned entry may have unintended consequences and is
        not advised.

        Examples
        --------
        >>> from hydra_zen import store, ZenStore
        >>> store(dict(x=1), name="a", group="fruit")
        >>> store.get_entry("fruit", "a")
        {'name': 'a',
         'group': 'fruit',
         'package': None,
         'provider': None,
         'node': {'x': 1}}
        """
        return _resolve_node(self._internal_repo[(group, name)], copy=True)

    def _set_entry(self, __entry: StoreEntry, overwrite: bool) -> None:
        """Insert `__entry` into the internal repo and enqueue it for Hydra's store.

        Raises `ValueError` if the entry's `(group, name)` key is already present
        and `overwrite` is false. If this store does not defer to Hydra's store, the
        queue is flushed immediately."""
        _group = __entry["group"]
        _name = __entry["name"]

        if not overwrite and (_group, _name) in self._internal_repo:
            raise ValueError(
                f"(name={__entry['name']} group={__entry['group']}): "
                f"Store entry already exists. Use a store initialized "
                f"with `ZenStore(overwrite_ok=True)` to overwrite config store "
                f"entries."
            )
        self._internal_repo[_group, _name] = __entry
        self._queue.add((_group, _name))

        if not self._deferred_store:
            self.add_to_hydra_store()

    def __contains__(self, key: Union[GroupName, Tuple[GroupName, NodeName]]) -> bool:
        """Checks if group or (group, node-name) exists in zen-store."""
        if key is None:
            return any(k[0] is None for k in self._internal_repo)  # pragma: no branch

        elif isinstance(key, str):
            # A group "contains" its subgroups: "a" matches "a" and "a/b", not "ab".
            key_w_end: str = key + "/"
            return any(
                key == group or group.startswith(key_w_end)
                for group, _ in self._internal_repo
                if group is not None
            )
        return key in self._internal_repo

    def __iter__(self) -> Generator[StoreEntry, None, None]:
        """Yields all entries in this store.

        Notes
        -----
        Mutating the returned mappings will not affect the store's internal entries.
        Mutating a node in an entry may have unintended consequences and is not
        advised.

        Examples
        --------
        >>> from hydra_zen import store
        >>> store(dict(x=1), name="a", group="fruit")
        >>> store(dict(x=2), name="b", group="fruit/orange")
        >>> store(dict(x=3), name="c", group="veggie")

        >>> list(store)
        [{'name': 'a',
          'group': 'fruit',
          'package': None,
          'provider': None,
          'node': {'x': 1}},
         {'name': 'b',
          'group': 'fruit/orange',
          'package': None,
          'provider': None,
          'node': {'x': 2}},
         {'name': 'c',
          'group': 'veggie',
          'package': None,
          'provider': None,
          'node': {'x': 3}}]
        """
        yield from (_resolve_node(v, copy=True) for v in self._internal_repo.values())

    def add_to_hydra_store(self, overwrite_ok: Optional[bool] = None) -> None:
        """Adds all of this store's enqueued entries to Hydra's global config store.

        This method need not be called for a store initialized as
        `ZenStore(deferred_hydra_store=False)`.

        Parameters
        ----------
        overwrite_ok : Optional[bool]
            If `False`, this method raises `ValueError` if an entry in Hydra's config
            store will be overwritten. Defaults to the value of `overwrite_ok`
            specified when initializing this store.

        Examples
        --------
        >>> from hydra_zen import ZenStore
        >>> store1 = ZenStore()
        >>> store2 = ZenStore()
        >>> store1({'a': 1}, name="x")
        >>> store1.add_to_hydra_store()
        >>> store2({'a': 2}, name="x")
        >>> store2.add_to_hydra_store()
        ValueError: (name=x group=None): Hydra config store entry already exists. Specify `overwrite_ok=True` to enable replacing config store entries

        >>> store2.add_to_hydra_store(overwrite_ok=True)  # successfully overwrites entry
        """
        _store = ConfigStore.instance().store

        # Iterate over a snapshot: the queue is drained as entries are added.
        for key in tuple(self._queue):
            entry = _resolve_node(self._internal_repo[key], copy=False)
            if (
                (
                    overwrite_ok is False
                    or (overwrite_ok is None and not self._overwrite_ok)
                )
                and self._exists_in_hydra_store(
                    name=entry["name"], group=entry["group"]
                )
                # It is okay if we are overwriting Hydra's default store
                and not (
                    (entry["name"], entry["group"]) == ("config", "hydra")
                    and ConfigStore.instance().repo["hydra"]["config.yaml"].provider
                    == "hydra"
                )
            ):
                raise ValueError(
                    f"(name={entry['name']} group={entry['group']}): "
                    f"Hydra config store entry already exists. Specify "
                    f"`overwrite_ok=True` to enable replacing config store entries"
                )
            _store(**entry)
            self._queue.discard(key)

    def _exists_in_hydra_store(
        self,
        *,
        name: NodeName,
        group: GroupName,
        hydra_store: Optional[ConfigStore] = None,
    ) -> bool:
        """Report whether Hydra's config store already holds `(group, name)`.

        Parameters
        ----------
        name : NodeName
        group : GroupName
            Subgroups are separated by "/".
        hydra_store : Optional[ConfigStore]
            The store to inspect. Defaults to Hydra's current global store, which is
            looked up at call time (rather than once, at definition time) so that the
            check always inspects the same store that `add_to_hydra_store` writes to.
        """
        if hydra_store is None:
            hydra_store = ConfigStore.instance()

        repo = hydra_store.repo

        if group is not None:
            # Walk the nested group dictionaries, one path segment at a time.
            for group_name in group.split("/"):
                repo = repo.get(group_name)
                if repo is None:
                    return False

        # Hydra's ConfigStore appends ".yaml" only to names that lack it; mirror that
        # so that an entry named e.g. "foo.yaml" is looked up under the right key.
        entry_key = name if name.endswith(".yaml") else name + ".yaml"
        return entry_key in repo
# hydra-zen's pre-instantiated store, exported via `__all__`. Both config creation
# and addition to Hydra's global config store are deferred.
store: ZenStore = ZenStore(
    name="zen_store",
    deferred_to_config=True,
    deferred_hydra_store=True,
)