Source code for hydra_zen.structured_configs._type_guards

# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
# pyright: strict
from dataclasses import MISSING
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, Type, Union

from typing_extensions import TypeGuard

from hydra_zen.funcs import get_obj, zen_processing
from hydra_zen.structured_configs._utils import safe_name
from hydra_zen.typing import Builds, Just, PartialBuilds
from hydra_zen.typing._implementations import DataClass_, HasTarget, HasTargetInst

from ._globals import (
    JUST_FIELD_NAME,
    PARTIAL_FIELD_NAME,
    TARGET_FIELD_NAME,
    ZEN_PARTIAL_FIELD_NAME,
    ZEN_PROCESSING_LOCATION,
    ZEN_TARGET_FIELD_NAME,
)

__all__ = ["is_partial_builds", "uses_zen_processing", "is_dataclass"]

# We need to check if things are Builds, Just, PartialBuilds to a higher
# fidelity than is provided by `isinstance(..., <Protocol>)`. I.e. we want to
# check that the desired attributes are present *and* that their values match
# those of the protocols. Failing to heed this would, for example, cause any
# `Builds` that happens to have a `path` attribute to be treated as `Just` in
# `get_target`.
#
# The following functions perform these desired checks. Note that they do not
# require that the provided object be a dataclass; this enables compatibility
# with omegaconf containers.
#
# These are not part of the public API for now, but they may be in the future.


def safe_getattr(obj: Any, field: str, *default: Any) -> Any:
    # We must access slotted class-attributes from a dataclass type
    # via its `__dataclass_fields__`. Otherwise we will get a member
    # descriptor

    assert len(default) < 2
    if (
        hasattr(obj, "__slots__")
        and isinstance(obj, type)
        and is_dataclass(obj)
        and field in obj.__slots__  # type: ignore
    ):
        try:
            _field = obj.__dataclass_fields__[field]
            if _field.default_factory is not MISSING or _field.default is MISSING:
                raise AttributeError

            return _field.default

        except (KeyError, AttributeError):
            if default:
                return default[0]

            raise AttributeError(
                f"type object '{safe_name(obj)}' has no attribute '{field}'"
            )

    return getattr(obj, field, *default)


def _get_target(x: Union[HasTarget, HasTargetInst]) -> Any:
    """Return the value stored on `x` under `TARGET_FIELD_NAME`.

    The lookup goes through `safe_getattr` with no fallback, so a missing
    attribute raises `AttributeError`.
    """
    target = safe_getattr(x, TARGET_FIELD_NAME)
    return target


def is_builds(x: Any) -> TypeGuard[Builds[Any]]:
    """Return `True` if `x` exposes an attribute named `TARGET_FIELD_NAME`.

    Only the presence of the attribute is checked; its value is not inspected.
    """
    has_target_field = hasattr(x, TARGET_FIELD_NAME)
    return has_target_field


def is_just(x: Any) -> TypeGuard[Just[Any]]:
    """Return `True` if `x` looks like a `Just` config.

    `x` must have both the target field and the `JUST_FIELD_NAME` field, and
    its target must either equal the target of `Just` or be `get_obj` itself.
    """
    if not (is_builds(x) and hasattr(x, JUST_FIELD_NAME)):
        return False

    target = _get_target(x)
    if target == _get_target(Just) or target is get_obj:
        return True
    return False


# `is_dataclass` is exposed with two faces: static type checkers see a stub
# whose return type is a `TypeGuard` over `DataClass_` instances and types,
# whereas at runtime the name is bound directly to `dataclasses.is_dataclass`.
if TYPE_CHECKING:  # pragma: no cover

    def is_dataclass(obj: Any) -> TypeGuard[Union[DataClass_, Type[DataClass_]]]: ...

else:
    from dataclasses import is_dataclass


def is_old_partial_builds(x: Any) -> bool:  # pragma: no cover
    """Return `True` if `x` is a partial config in the `_partial_target_` style.

    `x` qualifies when it has both the target field and a `_partial_target_`
    attribute, its target is `"hydra_zen.funcs.partial"` (or `partial`
    itself), and its `_partial_target_` satisfies `is_just`.
    """
    # Coverage is deliberately not tracked for this function: it is only used
    # from `get_target`, and that branch is covered there.
    if not (is_builds(x) and hasattr(x, "_partial_target_")):
        return False

    target = _get_target(x)
    if not (target == "hydra_zen.funcs.partial" or target is partial):
        return False

    return is_just(safe_getattr(x, "_partial_target_"))


def uses_zen_processing(x: Any) -> TypeGuard[Builds[Any]]:
    """Returns `True` if the input is a targeted structured config that relies
    on zen-processing features during its instantiation process. See notes for
    more details

    Parameters
    ----------
    x : Any

    Returns
    -------
    uses_zen : bool

    Notes
    -----
    In order to support zen :ref:`meta-fields <meta-field>` and
    :ref:`zen wrappers <zen-wrapper>`, hydra-zen redirects Hydra to an
    intermediary function – `hydra_zen.funcs.zen_processing` – during
    instantiation; i.e. `zen_processing` is made to be the `_target_` of the
    config and `_zen_target` indicates the object that is ultimately being
    configured for instantiation.

    Examples
    --------
    >>> from hydra_zen import builds, uses_zen_processing, to_yaml
    >>> ConfA = builds(dict, a=1)
    >>> ConfB = builds(dict, a=1, zen_partial=True)
    >>> ConfC = builds(dict, a=1, zen_wrappers=lambda x: x)
    >>> ConfD = builds(dict, a=1, zen_meta=dict(hidden_field=None))
    >>> ConfE = builds(dict, a=1, zen_meta=dict(hidden_field=None), zen_partial=True)
    >>> uses_zen_processing(ConfA)
    False
    >>> uses_zen_processing(ConfB)
    False
    >>> uses_zen_processing(ConfC)
    True
    >>> uses_zen_processing(ConfD)
    True
    >>> uses_zen_processing(ConfE)
    True

    Demonstrating the indirection that is used to facilitate zen-processing
    features.

    >>> print(to_yaml(ConfE))
    _target_: hydra_zen.funcs.zen_processing
    _zen_target: builtins.dict
    _zen_partial: true
    _zen_exclude:
    - hidden_field
    a: 1
    hidden_field: null
    """
    # A zen-processed config must carry both the target field and the
    # zen-target field.
    if not is_builds(x) or not hasattr(x, ZEN_TARGET_FIELD_NAME):
        return False

    # The target must be `zen_processing`, given either by its import-path
    # string or as the function object itself.
    attr = _get_target(x)
    if attr != ZEN_PROCESSING_LOCATION and attr is not zen_processing:
        return False
    return True


def is_partial_builds(x: Any) -> TypeGuard[PartialBuilds[Any]]:
    """
    Returns `True` if the input is a targeted structured config that entails
    partial instantiation, either via `_partial_=True` [1]_ or via
    `_zen_partial=True`.

    Parameters
    ----------
    x : Any

    Returns
    -------
    is_partial_config : bool

    References
    ----------
    .. [1] https://hydra.cc/docs/advanced/instantiate_objects/overview/#partial-instantiation

    See Also
    --------
    uses_zen_processing

    Examples
    --------
    >>> from hydra_zen import is_partial_builds

    An example involving a basic structured config

    >>> from dataclasses import dataclass
    >>> @dataclass
    ... class A:
    ...     _target_ : str = 'builtins.int'
    ...     _partial_ : bool = True
    >>> is_partial_builds(A)
    True
    >>> is_partial_builds(A(_partial_=False))
    False

    An example of a config that leverages partial instantiation via
    zen-processing

    >>> from hydra_zen import builds, uses_zen_processing, instantiate
    >>> Conf = builds(int, 0, zen_partial=True, zen_meta=dict(a=1))
    >>> hasattr(Conf, "_partial_")
    False
    >>> uses_zen_processing(Conf)
    True
    >>> is_partial_builds(Conf)
    True
    >>> instantiate(Conf)
    functools.partial(<class 'int'>, 0)
    """
    if is_builds(x):
        return (
            # check if partial'd config via Hydra
            safe_getattr(x, PARTIAL_FIELD_NAME, False)
            is True
        ) or (
            # check if partial'd config via `zen_processing`
            uses_zen_processing(x)
            and (safe_getattr(x, ZEN_PARTIAL_FIELD_NAME, False) is True)
        )
    return False


class HasOrigin(Protocol): __origin__: Type[Any] def is_generic_type(x: Any) -> TypeGuard[HasOrigin]: return ( hasattr(x, "__origin__") and hasattr(x, "__args__") and hasattr(x, "__parameters__") and isinstance(x.__origin__, type) )