hydra_zen.zen

hydra_zen.zen(func, /, pre_call, ZenWrapper)

A wrapper that returns a function that will auto-extract, resolve, and instantiate fields from an input config based on the wrapped function’s signature.

>>> fn = lambda x, y, z : x+y+z
>>> wrapped_fn = zen(fn)

>>> cfg = dict(x=1, y=builds(int, 4), z="${y}", unused=100)
>>> wrapped_fn(cfg)  # x=1, y=4, z=4
9

The main purpose of zen is to enable a user to write/use Hydra-agnostic functions as the task functions for their Hydra app. See “Notes” for more details.

Parameters:
func: Callable[Sig, R], positional-only

The function being wrapped.

unpack_kwargs: bool, optional (default=False)

If True, a **kwargs field in the wrapped function’s signature will be populated by all of the input config entries that are not specified by the rest of the signature (and that are not specified by the exclude argument).

pre_call: Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]]

One or more functions that will be called with the input config prior to the wrapped function. When an iterable of pre-call functions is given, they are called from left (low-index) to right (high-index).

This is useful, e.g., for seeding a RNG prior to the instantiation phase that is triggered when calling the wrapped function.

resolve_pre_call: bool, optional (default=True)

If True, the config passed to the zen-wrapped function has its interpolated fields resolved prior to being passed to any pre-call functions. Otherwise, the interpolation occurs after the pre-call functions are called.

exclude: Optional[str | Iterable[str]]

Specifies one or more parameter names in the function’s signature that will not be extracted from input configs by the zen-wrapped function.

A single string of comma-separated names can be specified.

run_in_context: bool, optional (default=False)

If True, the zen-wrapped function - and the pre_call function, if specified - is run in a copied contextvars.Context; i.e. changes made to any contextvars.ContextVar will be isolated to that call of the wrapped function.

run_in_context is not supported for async functions.

ZenWrapper: Type[hydra_zen.wrapper.Zen], optional (default=Zen)

If specified, a subclass of Zen that customizes the behavior of the wrapper.

instantiation_wrapper: Optional[Callable[[F2], F2]], optional (default=None)

If specified, a function that wraps the task function and all instantiation-targets before they are called.

This can be used to introduce a layer of validation or logging to all instantiation calls in your application.
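
The isolation described for run_in_context can be seen with the standard library alone. The sketch below is an illustration under stated assumptions, not hydra-zen's own code: it runs a hypothetical task function inside a copy of the current contextvars.Context and shows that the change made to the ContextVar does not leak out of that call. The names request_id and task are made up for this example.

```python
import contextvars

# Hypothetical context variable; any ContextVar in your application behaves the same way.
request_id = contextvars.ContextVar("request_id", default="unset")


def task():
    # Mutate the context variable from inside the task.
    request_id.set("changed-inside-task")
    return request_id.get()


# Run the task in a *copy* of the current context.
ctx = contextvars.copy_context()
inside = ctx.run(task)

# The change is visible inside the copied context only.
outside = request_id.get()
print(inside)   # value seen by the task
print(outside)  # value seen by the caller afterwards
```

The copied context keeps the modified value (ctx[request_id]), while the caller's context still sees the default.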

Returns:
wrapped: Zen[Sig, R]

A callable with signature (conf: ConfigLike, /) -> R

The wrapped function is an instance of hydra_zen.wrapper.Zen and accepts a single Hydra config (a dataclass, dictionary, or omegaconf container). The parameters of the wrapped function’s signature determine the fields that are extracted from the config; only those fields that are accessed will be resolved and instantiated.

See also

hydra_zen.wrapper.Zen

Implements the wrapping logic that is exposed by hydra_zen.zen.

Notes

The following pseudo code conveys the core functionality of zen:

from hydra_zen import instantiate as inst

def zen(func):
    sig = get_signature(func)

    def wrapped(cfg):
        cfg = resolve_interpolated_fields(cfg)
        kwargs = {p: inst(getattr(cfg, p)) for p in sig}
        return func(**kwargs)
    return wrapped
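
The signature-driven extraction step in the pseudo code can be tried out with the standard library alone. The sketch below is an illustration, not hydra-zen's implementation: mini_zen is a hypothetical name, it accepts plain dictionaries only, and it skips interpolation and instantiation entirely. It also assumes that a parameter with a default value may be absent from the config.

```python
import inspect

# Hypothetical helper for illustration only; it is not part of hydra-zen.
def mini_zen(func, exclude=()):
    """Wrap func so that it is called with fields pulled from a dict config.

    exclude is an iterable of parameter names that are never extracted.
    """
    excluded = frozenset(exclude)
    extractable_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # Keep only the named parameters that we are allowed to extract.
    params = {
        name: p
        for name, p in inspect.signature(func).parameters.items()
        if p.kind in extractable_kinds and name not in excluded
    }

    def wrapped(cfg):
        kwargs = {}
        for name, p in params.items():
            if name in cfg:
                kwargs[name] = cfg[name]
            elif p.default is inspect.Parameter.empty:
                # Required parameter with no matching config field.
                raise KeyError(f"cfg is missing the field: {name}")
        return func(**kwargs)

    return wrapped


def add(x, y=10):
    return x + y


# Fields that are not in the signature ("unused") are ignored.
total = mini_zen(add)({"x": 1, "y": 2, "unused": 99})
# An excluded parameter falls back to the function's default.
defaulted = mini_zen(add, exclude=("y",))({"x": 1, "y": 2})
```

The real zen additionally resolves interpolated fields and instantiates sub-configs before calling the wrapped function, as the pseudo code shows.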

The presence of a parameter named “zen_cfg” in the wrapped function’s signature will cause zen to pass the full, resolved config to that field. This specific parameter name can be overridden via Zen.CFG_NAME.

Specifying config_path via Zen.hydra_main is only supported for Hydra 1.3.0+.

Examples

Basic Usage

>>> from hydra_zen import zen, make_config, builds
>>> def f(x, y): return x + y
>>> zen_f = zen(f)

The wrapped function – zen_f – accepts a single argument: a Hydra-compatible config that has the attributes “x” and “y”:

>>> zen_f
zen[f(x, y)](cfg, /)

“Configs” – dataclasses, dictionaries, and omegaconf containers – are acceptable inputs to zen-wrapped functions. Interpolated fields will be resolved and sub-configs will be instantiated. Excess fields in the config are unused.

>>> zen_f(make_config(x=1, y=2, z=999))  # z is not used
3
>>> zen_f(dict(x=2, y="${x}"))  # y will resolve to 2
4
>>> zen_f(dict(x=2, y=builds(int, 10)))  # y will instantiate to 10
12

The wrapped function can be accessed directly:

>>> zen_f.func
<function __main__.f(x, y)>
>>> zen_f.func(-1, 1)
0

zen is compatible with partial’d functions.

>>> from functools import partial
>>> pf = partial(lambda x, y: x + y, x=10)
>>> zpf = zen(pf)
>>> zpf(dict(y=1))
11
>>> zpf(dict(x='${y}', y=1))
2

One can specify exclude to prevent particular variables from being extracted from a config:

>>> def g(x=1, y=2): return (x, y)
>>> cfg = {"x": -10, "y": -20}
>>> zen(g)(cfg)  # extracts x & y from config to call g
(-10, -20)
>>> zen(g, exclude="x")(cfg)  # extracts y from config to call g(x=1, ...)
(1, -20)
>>> zen(g, exclude="x,y")(cfg)  # defers to g's defaults
(1, 2)

Populating a **kwargs field via unpack_kwargs=True:

>>> def h(a, **kw):
...     return a, kw
>>> cfg = dict(a=1, b=22)
>>> zen(h, unpack_kwargs=False)(cfg)
(1, {})
>>> zen(h, unpack_kwargs=True)(cfg)
(1, {'b': 22})

Passing Through The Full Input Config

Some task functions need the full config, e.g. to gain access to sub-configs. Including a parameter named zen_cfg in the task function’s signature signals zen to pass the full, resolved config to that parameter.

>>> def zf(x: int, zen_cfg):
...     return x, zen_cfg
>>> zen(zf)(dict(x=1, y="${x}", foo="bar"))
(1, {'x': 1, 'y': 1, 'foo': 'bar'})

Including a pre-call function

Given that a zen-wrapped function will automatically extract and instantiate config fields upon being called, it can be necessary to include a pre-call step that occurs prior to any instantiation. zen can be passed one or more pre-call functions that will be called with the input config as a precursor to calling the decorated function.

Consider the following scenario, where instantiating the input config involves drawing a random value that we want to make deterministic with a configurable seed. We will use a pre-call function to seed the RNG prior to the instantiation.

>>> import random
>>> from hydra_zen import builds, zen
>>>
>>> def func(rand_val: int): return rand_val
>>>
>>> cfg = dict(
...         seed=0,
...         rand_val=builds(random.randint, 0, 10),
... )
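>>> # Pass the function to zen directly...
>>> wrapped = zen(func, pre_call=lambda cfg: random.seed(cfg.seed))
>>> # ...or call zen with options only and apply the result to the function.
>>> def f1(rand_val: int):
...     return rand_val
>>> zf1 = zen(pre_call=lambda cfg: random.seed(cfg.seed))(f1)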
>>> [zf1(cfg) for _ in range(10)]
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6]

Using `zen` instead of `@hydra.main`

The object returned by zen provides a convenience method – Zen.hydra_main – to generate a CLI for a zen-wrapped task function:

# example.py
from hydra_zen import zen, store

@store(name="my_app")
def task(x: int, y: int):
    print(x + y)

if __name__ == "__main__":
    store.add_to_hydra_store()
    zen(task).hydra_main(config_name="my_app", config_path=None, version_base="1.2")

$ python example.py x=1 y=2
3

Validating input configs

An input config can be validated against a zen-wrapped function – without calling said function – via the validate method.

>>> def f2(x: int): ...
>>> zen_f = zen(f2)
>>> zen_f.validate({"x": 1})  # OK
>>> zen_f.validate({"y": 1})  # Missing x
HydraZenValidationError: `cfg` is missing the following fields: x

Validation propagates through zen-wrapped pre-call functions:

>>> zen_f2 = zen(f2, pre_call=zen(lambda seed: None))
>>> zen_f2.validate({"x": 1, "seed": 10})  # OK
>>> zen_f2.validate({"x": 1})  # Missing seed as required by pre-call
HydraZenValidationError: `cfg` is missing the following fields: seed