Source code for hydra_zen.third_party.beartype

# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
import inspect
from typing import Any, Callable, TypeVar, cast

import beartype as bt

from hydra_zen._utils.coerce import coerce_sequences

_T = TypeVar("_T", bound=Callable[..., Any])

__all__ = ["validates_with_beartype"]


def validates_with_beartype(obj: _T) -> _T:
    """Enables runtime type-checking of values, via the library ``beartype``.

    That is, ``obj = validates_with_beartype(obj)`` adds runtime type-checking
    to all calls of ``obj(*args, **kwargs)``, based on the type-annotations
    specified in the signature of ``obj``.

    This is designed to be used as a "zen-wrapper"; see Examples for details.

    Parameters
    ----------
    obj : Callable
        The function or class to which runtime type-checking is added.

    Returns
    -------
    obj_w_validation : Callable
        A wrapped function, or a class whose ``__init__`` method has been
        wrapped in-place.

    See Also
    --------
    hydra_zen.third_party.pydantic.validates_with_pydantic

    Notes
    -----
    ``beartype`` [1]_ must be installed as a separate dependency to use this
    validator. Refer to beartype's documentation [2]_ to see which varieties
    of types it does and does not support.

    Using ``validates_with_beartype`` as a ``zen_wrapper`` makes the resulting
    YAML configs depend on beartype: these configs will also be validated by
    beartype upon instantiation.

    **Data-Coercion Behavior**

    hydra-zen adds a data-coercion step that is not performed by ``beartype``.
    This only affects fields of ``obj`` that have a (non-string) sequence-type
    annotation and are passed list-type data. All other fields rely solely on
    beartype's native behavior.

    For example, if a field annotated with ``Tuple`` is passed a list, that
    list is cast to a tuple. See the Examples section for more details.

    References
    ----------
    .. [1] https://github.com/beartype/beartype
    .. [2] https://github.com/beartype/beartype#compliance

    Examples
    --------
    **Basic usage**

    >>> from hydra_zen.third_party.beartype import validates_with_beartype
    >>> from beartype.cave import ScalarTypes

    >>> def f(x: ScalarTypes): return x  # a scalar is any real-valued number
    >>> f([1, 2])  # doesn't catch bad input
    [1, 2]

    >>> val_f = validates_with_beartype(f)  # f + validation
    >>> val_f([1, 2])
    BeartypeCallHintPepParamException: @beartyped f() parameter x=[1, 2] violates type hint [...]

    Applying ``validates_with_beartype`` to a class object will wrap its
    ``__init__`` method in-place.

    >>> class A:
    ...     def __init__(self, x: ScalarTypes): ...
    >>> validates_with_beartype(A)  # wrapping occurs in-place
    <class '__main__.A'>
    >>> A([1, 2])
    BeartypeCallHintPepParamException: @beartyped A.__init__() parameter x=[1, 2] violates type hint [...]

    **Adding beartype validation to configs**

    This is designed to be used with the ``zen_wrappers`` feature of `builds`.

    >>> from hydra_zen import builds, instantiate
    >>> Conf = builds(f, populate_full_signature=True, zen_wrappers=validates_with_beartype)

    Instantiating ``Conf`` will prompt ``beartype`` to check the types of
    configured parameters against the corresponding annotations on ``f``.

    >>> instantiate(Conf, x=10)  # 10 is a scalar: ok!
    10
    >>> instantiate(Conf, x=[1, 2])  # [1, 2] is not a scalar: roar!
    BeartypeCallHintPepParamException: @beartyped f() parameter x=[1, 2] violates type hint [...]

    Consider using :func:`~hydra_zen.make_custom_builds_fn` to add validation
    to all configs.

    **Sequence-coercion behavior for compatibility with Hydra**

    Note that sequence-coercion is enabled to ensure smooth compatibility with
    Hydra, as Hydra interprets all (non-string) sequential data structures as
    lists.

    >>> def g(x: tuple): return x  # note the annotation
    >>> g([1, 2, 3])
    [1, 2, 3]

    >>> validates_with_beartype(g)([1, 2, 3])  # input: list, output: tuple
    (1, 2, 3)

    Only inputs of type list and ``ListConfig`` get cast in this way, since
    Hydra reads non-string sequential data from configs as one of these two
    types.

    >>> validates_with_beartype(g)({1, 2, 3})  # input: a set
    BeartypeCallHintPepParamException: @beartyped g() parameter x={1, 2, 3} violates type hint [...]
    """
    if inspect.isclass(obj) and hasattr(obj, "__init__"):
        # Classes: wrap ``__init__`` in-place so that calls to ``obj(...)`` are type-checked.
        obj.__init__ = bt.beartype(obj.__init__)
        target = obj
    else:
        # Other callables: beartype returns a new, type-checked wrapper.
        target = bt.beartype(obj)

    # hydra-zen's coercion step: list inputs to sequence-annotated parameters
    # are cast (e.g. list -> tuple), since Hydra passes sequential data as lists.
    target = coerce_sequences(target)
    return cast(_T, target)
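The data-coercion step described in the docstring can be illustrated with a minimal, stdlib-only sketch. The decorator ``coerce_lists_to_tuples`` below is hypothetical. It is not hydra-zen's ``coerce_sequences``, and it handles only parameters annotated with ``tuple`` or ``Tuple[...]``. It does not attempt to cover ``ListConfig`` or any other sequence types the real helper may handle.

```python
import functools
import inspect
from typing import Any, Callable, Tuple, get_origin, get_type_hints


def coerce_lists_to_tuples(func: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical, simplified stand-in for hydra-zen's ``coerce_sequences``:
    # only tuple-annotated parameters are handled here.
    sig = inspect.signature(func)
    hints = get_type_hints(func)
    # Parameters annotated as ``tuple`` or ``Tuple[...]``
    # (get_origin(Tuple[int, int]) is tuple; get_origin(tuple) is None).
    tuple_params = {
        name
        for name, hint in hints.items()
        if name != "return" and (hint is tuple or get_origin(hint) is tuple)
    }

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = sig.bind(*args, **kwargs)
        for name in tuple_params:
            value = bound.arguments.get(name)
            # Only lists are cast; sets, tuples, etc. pass through unchanged.
            if isinstance(value, list):
                bound.arguments[name] = tuple(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper


def g(x: tuple) -> tuple:
    return x


def h(pair: Tuple[int, int]) -> Tuple[int, int]:
    return pair


coerced_g = coerce_lists_to_tuples(g)
print(coerced_g([1, 2, 3]))  # (1, 2, 3)
print(coerced_g((4, 5)))  # (4, 5): tuples are left as-is
print(coerce_lists_to_tuples(h)(pair=[7, 8]))  # (7, 8)
```

In ``validates_with_beartype``, the coercion wrapper is applied after ``bt.beartype``. For functions, this means it sits outside the type-checking wrapper, so list inputs are converted before beartype inspects them.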
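The function body treats classes and other callables differently. A class has its ``__init__`` replaced in-place, so the same class object is returned and every later instantiation is checked. A function, by contrast, is returned as a new wrapper. The sketch below mirrors that dispatch without depending on beartype. The names ``check_ints`` and ``apply_validator`` are hypothetical, and ``check_ints`` is a deliberately tiny toy validator, not a substitute for beartype's checks.

```python
import functools
import inspect
from typing import Any, Callable, TypeVar, cast, get_type_hints

_T = TypeVar("_T", bound=Callable[..., Any])


def check_ints(func: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical toy validator: rejects non-int values for ``int``-annotated params.
    sig = inspect.signature(func)
    hints = get_type_hints(func)

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if hints.get(name) is int and not isinstance(value, int):
                raise TypeError(
                    f"{func.__qualname__}() parameter {name}={value!r} is not an int"
                )
        return func(*args, **kwargs)

    return wrapper


def apply_validator(obj: _T, validator: Callable[..., Any]) -> _T:
    # Hypothetical helper mirroring the class/function dispatch in the module.
    if inspect.isclass(obj):
        # Classes: replace __init__ in-place, so later calls obj(...) are checked.
        setattr(obj, "__init__", validator(obj.__init__))
        return cast(_T, obj)
    # Functions: return a new wrapper; the original function is left unchanged.
    return cast(_T, validator(obj))


class A:
    def __init__(self, x: int) -> None:
        self.x = x


def f(x: int) -> int:
    return x


print(apply_validator(A, check_ints) is A)  # True: same class object, modified in-place
val_f = apply_validator(f, check_ints)

print(A(1).x)  # 1
print(val_f(2))  # 2
print(f("2"))  # 2: the unwrapped original performs no check
try:
    A("1")
except TypeError as err:
    print(err)  # A.__init__() parameter x='1' is not an int
```

Because the class is modified in-place, ``validates_with_beartype(A)`` returns ``A`` itself. Calling it only for its side effect, as the docstring example does, is therefore enough to add validation.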