# Source code for hydra_zen.typing._implementations

# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT

# pyright: strict

import sys
import types
from dataclasses import _MISSING_TYPE  # pyright: ignore[reportPrivateUsage]
from datetime import timedelta
from enum import Enum
from pathlib import Path, PosixPath, WindowsPath
from typing import (
    TYPE_CHECKING,
    Any,
    ByteString,
    Callable,
    ClassVar,
    Dict,
    FrozenSet,
    List,
    Mapping,
    NewType,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from omegaconf import DictConfig, ListConfig
from typing_extensions import (
    Final,
    Literal,
    ParamSpec,
    Protocol,
    Required,
    Self,
    TypeAlias,
    TypedDict,
    runtime_checkable,
)

# Names exported by `from <this module> import *`.
__all__ = [
    "Just",
    "Builds",
    "PartialBuilds",
    "Partial",
    "Importable",
    "SupportedPrimitive",
    "ZenWrappers",
    "ZenPartialBuilds",
    "HydraPartialBuilds",
    "ZenConvert",
]  # keep as a plain list literal so static analyzers can read it

# ParamSpec capturing a callable's parameters (used by `BuildsWithSig`).
P = ParamSpec("P")
# Generic return-type variable.
R = TypeVar("R")


class EmptyDict(TypedDict):
    """A TypedDict that declares no keys.

    Included in `_SupportedPrimitive` because an empty dict is not covered by
    that alias's `Mapping[...]` member.
    """

    pass


# Covariant type variable; parameterizes the `Builds`-family protocols.
T = TypeVar("T", covariant=True)
# Invariant type variables.
T2 = TypeVar("T2")
T3 = TypeVar("T3")

# Bound to callables; used by the `ZenWrapper` aliases.
T4 = TypeVar("T4", bound=Callable[..., Any])


# Either an instance of `T` or the type `T` itself.
InstOrType: TypeAlias = Union[T, Type[T]]


# Type checkers see the (generic) `dataclasses.Field` from the type stubs. At
# runtime a structurally-equivalent Protocol stands in for it, so that
# annotations such as `Field[Any]` (used by `DataClass_`) can be evaluated.
if TYPE_CHECKING:
    from dataclasses import Field  # provided by typestub but not generic at runtime
else:

    class Field(Protocol[T2]):
        # Mirrors (a subset of) the attributes of `dataclasses.Field`.
        name: str
        type: Type[T2]
        default: T2
        default_factory: Callable[[], T2]
        repr: bool
        hash: Optional[bool]
        init: bool
        compare: bool
        metadata: Mapping[str, Any]


@runtime_checkable
class Partial(Protocol[T2]):
    """A protocol that matches against `functools.partial`"""

    # Calling the partial produces the wrapped callable's result type.
    __call__: Callable[..., T2]

    # The wrapped callable, and the positional/keyword arguments bound to it.
    @property
    def func(self) -> Callable[..., T2]: ...

    @property
    def args(self) -> Tuple[Any, ...]: ...

    @property
    def keywords(self) -> Dict[str, Any]: ...

    # Constructor signature mirroring `functools.partial(func, *args, **kwargs)`.
    def __new__(
        cls: Type[Self], __func: Callable[..., T2], *args: Any, **kwargs: Any
    ) -> Self: ...

    # Type-checking only: on Python 3.9+ subscripting `Partial[...]` is
    # described as producing a `types.GenericAlias`.
    if TYPE_CHECKING and sys.version_info >= (3, 9):  # pragma: no cover

        def __class_getitem__(cls, item: Any) -> types.GenericAlias: ...


# A `str` distinguished for type checkers only (via `NewType`); presumably an
# OmegaConf interpolation string such as "${foo}" -- confirm at call sites.
InterpStr = NewType("InterpStr", str)


class DataClass_(Protocol):
    # Structurally matches objects exposing `__dataclass_fields__` (as dataclass
    # types and instances do).
    # doesn't provide __init__, __getattribute__, etc.
    __dataclass_fields__: ClassVar[Dict[str, Field[Any]]]


class DataClass(DataClass_, Protocol):
    # Like `DataClass_`, but also exposes a permissive constructor and
    # attribute get/set for arbitrary names.
    def __init__(self, *args: Any, **kwargs: Any) -> None: ...

    def __getattribute__(self, __name: str) -> Any: ...

    def __setattr__(self, __name: str, __value: Any) -> None: ...


# Matches objects whose `_target_` is a class-level string attribute;
# `runtime_checkable` allows `isinstance` checks against it.
@runtime_checkable
class HasTarget(Protocol):
    _target_: ClassVar[str]


# Like `HasTarget`, but `_target_` is declared as an instance attribute.
@runtime_checkable
class HasTargetInst(Protocol):
    _target_: str


# A dataclass-style config with a `_target_`; `T` (covariant) is presumably the
# type produced when the config is instantiated.
@runtime_checkable
class Builds(DataClass, Protocol[T]):
    _target_: ClassVar[str]


# A `Builds` whose `__init__` signature is described by the ParamSpec `P`.
class BuildsWithSig(Builds[T], Protocol[T, P]):
    def __init__(self, *args: P.args, **kwds: P.kwargs): ...


# A config targeting `hydra_zen.funcs.get_obj`, with `path` naming the object.
@runtime_checkable
class Just(Builds[T], Protocol[T]):
    path: str  # interpolated string for importing obj
    _target_: ClassVar[str] = "hydra_zen.funcs.get_obj"


# Marker for hydra-zen's own partial mechanism: `_zen_target` holds a target
# string and `_zen_partial` is always `True`.
class ZenPartialMixin(Protocol[T]):
    _zen_target: ClassVar[str]
    _zen_partial: ClassVar[Literal[True]] = True


# Marker for Hydra's native partial support (`_partial_ = True`).
class HydraPartialMixin(Protocol[T]):
    _partial_: ClassVar[Literal[True]] = True


# Partial config whose `_target_` is `hydra_zen.funcs.zen_processing`.
@runtime_checkable
class ZenPartialBuilds(Builds[T], ZenPartialMixin[T], Protocol[T]):
    _target_: ClassVar[str] = "hydra_zen.funcs.zen_processing"


# Partial config relying on Hydra's `_partial_` flag.
@runtime_checkable
class HydraPartialBuilds(Builds[T], HydraPartialMixin[T], Protocol[T]): ...


# Necessary, but not sufficient, check for PartialBuilds; useful for creating
# non-overlapping overloads
IsPartial: TypeAlias = Union[ZenPartialMixin[T], HydraPartialMixin[T]]

# Either flavor of partial config: hydra-zen's mechanism or Hydra's `_partial_`.
PartialBuilds: TypeAlias = Union[ZenPartialBuilds[T], HydraPartialBuilds[T]]

# Any targeted config, with or without signature information.
AnyBuilds: TypeAlias = Union[Builds[T], BuildsWithSig[T, Any]]

# A callable; declared as a TypeVar so generic functions can preserve the
# specific callable's type.
Importable = TypeVar("Importable", bound=Callable[..., Any])

# Primitive (non-container) members of `HydraSupportedType`.
# NOTE(review): `typing.ByteString` is deprecated per the `typing` docs;
# consider `Union[bytes, bytearray, memoryview]` before it is removed.
_HydraPrimitive: TypeAlias = Union[
    bool, None, int, float, str, ByteString, Path, WindowsPath, PosixPath
]

HydraSupportedType = Union[
    _HydraPrimitive,
    DataClass_,
    Type[DataClass_],
    ListConfig,
    DictConfig,
    Enum,
    _MISSING_TYPE,
    # Even though this is redundant with Sequence, it seems to
    # be needed for pyright to do proper checking of tuple contents
    Tuple["HydraSupportedType", ...],
    Sequence["HydraSupportedType"],
    Mapping[Any, "HydraSupportedType"],
]
"""Describes types that are compatible with Hydra -- they can be used in
configs provided to Hydra."""

# Types that are not Hydra-compatible themselves; presumably (per the name)
# hydra-zen supports them by generating a `builds` config -- confirm.
_SupportedViaBuilds = Union[
    Partial[Any],
    range,
    Set[Any],
    timedelta,
    types.SimpleNamespace,
]

# All non-container values accepted by hydra-zen's config-creation functions.
_SupportedPrimitive: TypeAlias = Union[
    _HydraPrimitive,
    ListConfig,
    DictConfig,
    Callable[..., Any],
    Enum,
    DataClass_,
    complex,
    _SupportedViaBuilds,
    EmptyDict,  # not covered by Mapping[..., ...]
]


SupportedPrimitive: TypeAlias = Union[
    _SupportedPrimitive,
    FrozenSet["SupportedPrimitive"],
    # Even though this is redundant with Sequence, it seems to
    # be needed for pyright to do proper checking of tuple contents
    Tuple["SupportedPrimitive", ...],
    # Mutable generic containers need to be invariant, so
    # we have to settle for Sequence/Mapping. While this
    # is overly permissive in terms of sequence-type, it
    # at least affords quality checking of sequence content
    Sequence["SupportedPrimitive"],
    # Mapping is covariant only in value
    Mapping[Any, "SupportedPrimitive"],
]
"""Describes types that are natively supported by hydra-zen's config-creation
functions."""

# Recursive alias: a user-supplied type `T2`, anything Hydra supports, and any
# (possibly nested) tuple/sequence/mapping/partial thereof.
CustomConfigType: TypeAlias = Union[
    T2,
    HydraSupportedType,
    Tuple["CustomConfigType[T2]", ...],
    Sequence["CustomConfigType[T2]"],
    Mapping[Any, "CustomConfigType[T2]"],
    Partial["CustomConfigType[T2]"],
    Partial[T2],
]
"""The type `CustomConfigType[MyType]` describes: `MyType`, all hydra-zen config-compatible types, and all hydra-zen compatible containers containing said types.

This is used for parameterizing `hydra_zen.BuildsFn` with custom type information. Example::

   from hydra_zen import BuildsFn
   from hydra_zen.typing import CustomConfigType

   class MyType: ...
   class BadType: ...

   class MyBuilds(BuildsFn[CustomConfigType[MyType]]):
       ...

   builds = MyBuilds.builds

   builds(dict, x=MyType(), y=[1, MyType()])  # type-checker: OK
   builds(dict, x=BadType(), y=[1, MyType()]) # type-checker: Bad!
   builds(dict, x=MyType(), y=[1, BadType()]) # type-checker: Bad!
"""


# A single zen-wrapper: `None`, a `Callable[[T4], T4]` given directly, a config
# (instance or type) that targets such a callable, or a `str` (presumably an
# interpolation/import path -- confirm).
ZenWrapper: TypeAlias = Union[
    None,
    Builds[Callable[[T4], T4]],
    PartialBuilds[Callable[[T4], T4]],
    Just[Callable[[T4], T4]],
    Type[Builds[Callable[[T4], T4]]],
    Type[PartialBuilds[Callable[[T4], T4]]],
    Type[Just[Callable[[T4], T4]]],
    Callable[[T4], T4],
    str,
]


# One zen-wrapper, or a sequence of them.
ZenWrappers: TypeAlias = Union[ZenWrapper[T4], Sequence[ZenWrapper[T4]]]


# Entries may be strings, dataclass configs (instances or types), or mappings
# from a string to `None`/a string/a sequence of strings (presumably Hydra's
# `group: option` defaults-list form -- confirm).
DefaultsList = List[
    Union[
        str, DataClass_, Type[DataClass_], Mapping[str, Union[None, str, Sequence[str]]]
    ]
]


# Lists all zen-convert settings and their types. Not part of public API
class AllConvert(TypedDict, total=True):
    dataclass: bool
    flat_target: bool


# used for runtime type-checking
# Maps each zen-convert option name to its expected type; keep in sync with
# the keys of `AllConvert`.
convert_types: Final = {"dataclass": bool, "flat_target": bool}

# Aliases for the fields of `StoreEntry`.
GroupName: TypeAlias = Optional[str]
NodeName: TypeAlias = str
Node: TypeAlias = Any


# TODO: make immutable
# Presumably one registered entry of hydra-zen's config store: the stored node
# plus its name/group/package/provider metadata -- confirm against the store.
class StoreEntry(TypedDict):
    name: NodeName
    group: GroupName  # `Optional[str]`; `None` presumably means "no group"
    package: Optional[str]
    provider: Optional[str]
    node: Node


class ZenConvert(TypedDict, total=False):
    """A TypedDict that provides a type-checked interface for specifying zen-convert
    options that configure the hydra-zen config-creation functions (e.g., `builds`,
    `just`, and `make_config`).

    Note that, at runtime, `ZenConvert` is simply a dictionary with type-annotations.
    There is no enforced runtime validation of its keys and values.

    Parameters
    ----------
    flat_target : bool
        If `True` (default), `builds(builds(f))` is equivalent to `builds(f)`. I.e.
        the second `builds` call will use the `_target_` field of its input, if it
        exists.

    dataclass : bool
        If `True` (default) any dataclass type/instance without a `_target_` field
        is automatically converted to a targeted config that will instantiate to
        that type/instance. Otherwise the dataclass type/instance will be passed
        through as-is.

        Note that this only works with statically-defined dataclass types, whereas
        :func:`~hydra_zen.make_config` and :py:func:`dataclasses.make_dataclass`
        dynamically generate dataclass types. Additionally, this feature is not
        compatible with a dataclass instance whose type possesses an `InitVar`
        field.

    Examples
    --------
    >>> from hydra_zen.typing import ZenConvert as zc
    >>> zc()
    {}
    >>> zc(dataclass=True)
    {'dataclass': True}
    >>> # static type-checker will raise, but runtime will not
    >>> zc(apple=1)  # type: ignore
    {'apple': 1}

    **Configuring dataclass auto-config behaviors**

    >>> from hydra_zen import instantiate as I
    >>> from hydra_zen import builds, just, make_config
    >>> from dataclasses import dataclass
    >>> @dataclass
    ... class B:
    ...     x: int
    >>> b = B(x=1)

    >>> I(just(b))
    B(x=1)
    >>> I(just(b, zen_convert=zc(dataclass=False)))  # returns omegaconf.DictConfig
    {'x': 1}

    >>> I(builds(dict, y=b))
    {'y': B(x=1)}
    >>> I(builds(dict, y=b, zen_convert=zc(dataclass=False)))  # returns omegaconf.DictConfig
    {'y': {'x': 1}}

    >>> I(make_config(y=b))  # returns omegaconf.DictConfig
    {'y': {'x': 1}}
    >>> I(make_config(y=b, zen_convert=zc(dataclass=True), hydra_convert="all"))
    {'y': B(x=1)}

    Auto-config support does not work with dynamically-generated dataclass types

    >>> just(make_config(z=1))
    HydraZenUnsupportedPrimitiveError: ...
    >>> I(just(make_config(z=1), zen_convert=zc(dataclass=False)))
    {'z': 1}

    A dataclass with a `_target_` field will not be converted:

    >>> @dataclass
    ... class BuildsStr:
    ...     _target_: str = 'builtins.str'
    ...
    >>> BuildsStr is just(BuildsStr)
    True
    >>> (builds_str := BuildsStr()) is just(builds_str)
    True
    """

    flat_target: bool
    dataclass: bool
class _AllPyDataclassOptions(TypedDict, total=False):
    # `dataclasses.make_dataclass` options available on every supported Python.
    cls_name: str
    namespace: Optional[Dict[str, Any]]
    bases: Tuple[Type[DataClass_], ...]
    init: bool
    repr: bool
    eq: bool
    order: bool
    unsafe_hash: bool
    frozen: bool


class _Py310Dataclass(_AllPyDataclassOptions, total=False):
    # py310+
    match_args: bool
    kw_only: bool
    slots: bool


class _Py311Dataclass(_Py310Dataclass, total=False):
    # py311+
    weakref_slot: bool


class _Py312Dataclass(_Py311Dataclass, total=False):
    # py312+
    module: Optional[str]


# Select the option set matching the running interpreter, newest first.
if sys.version_info >= (3, 12):  # pragma: no cover
    _StrictDataclassOptions = _Py312Dataclass
elif sys.version_info >= (3, 11):  # pragma: no cover
    _StrictDataclassOptions = _Py311Dataclass
elif sys.version_info >= (3, 10):
    _StrictDataclassOptions = _Py310Dataclass
else:
    _StrictDataclassOptions = _AllPyDataclassOptions


class StrictDataclassOptions(_StrictDataclassOptions):
    # Same keys as the version-appropriate option set, but `cls_name` is required.
    cls_name: Required[str]  # type: ignore
class DataclassOptions(_Py312Dataclass, total=False):
    """Specifies dataclass-creation options via `builds`, `just` et al.

    Note that, unlike :func:`dataclasses.make_dataclass`, the default value for
    `unsafe_hash` is `True` for hydra-zen's dataclass-generating functions.
    See the documentation for :func:`dataclasses.make_dataclass` for more
    details [1]_.

    Options that are not supported by the local Python version will be ignored by
    hydra-zen's config-creation functions.

    Parameters
    ----------
    cls_name : str, optional
        If specified, determines the name of the returned class object. Otherwise
        the name is inferred by hydra-zen.

    module : str | None, optional (default='typing')
        If specified, sets the `__module__` attribute of the resulting dataclass.

        Specifying the module string-path in which the dataclass was generated,
        and specifying `cls_name` as the symbol that references the dataclass,
        will enable pickle-compatibility for that dataclass. See the Examples
        section for clarification.

    target : str, optional (unspecified by default)
        If specified, overrides the `_target_` field set on the resulting
        dataclass.

    init : bool, optional (default=True)
        If true (the default), a `__init__()` method will be generated. If the
        class already defines `__init__()`, this parameter is ignored.

    repr : bool, optional (default=True)
        If true (the default), a `__repr__()` method will be generated. The
        generated repr string will have the class name and the name and repr of
        each field, in the order they are defined in the class. Fields that are
        marked as being excluded from the repr are not included. For example:
        `InventoryItem(name='widget', unit_price=3.0, quantity_on_hand=10)`.

    eq : bool, optional (default=True)
        If true (the default), an `__eq__()` method will be generated. This method
        compares the class as if it were a tuple of its fields, in order. Both
        instances in the comparison must be of the identical type.

    order : bool, optional (default=False)
        If true (the default is `False`), `__lt__()`, `__le__()`, `__gt__()`, and
        `__ge__()` methods will be generated. These compare the class as if it were
        a tuple of its fields, in order. Both instances in the comparison must be
        of the identical type. If order is true and eq is false, a `ValueError` is
        raised. If the class already defines any of `__lt__()`, `__le__()`,
        `__gt__()`, or `__ge__()`, then `TypeError` is raised.

    unsafe_hash : bool, optional (default=True for hydra-zen)
        hydra-zen's config-creation functions default to `True` here, whereas
        :func:`dataclasses.dataclass` defaults to `False`.

        If `False`, a `__hash__()` method is generated according to how `eq` and
        `frozen` are set. If `eq` and `frozen` are both true, by default
        `dataclass()` will generate a `__hash__()` method for you. If `eq` is true
        and `frozen` is false, `__hash__()` will be set to `None`, marking it
        unhashable. If `eq` is false, `__hash__()` will be left untouched meaning
        the `__hash__()` method of the superclass will be used (if the superclass
        is object, this means it will fall back to id-based hashing).

    frozen : bool, optional (default=False)
        If true (the default is `False`), assigning to fields will generate an
        exception. This emulates read-only frozen instances.

    match_args : bool, optional (default=True)
        (*New in version 3.10*) If true (the default is `True`), the
        `__match_args__` tuple will be created from the list of parameters to the
        generated `__init__()` method (even if `__init__()` is not generated, see
        above). If false, or if `__match_args__` is already defined in the class,
        then `__match_args__` will not be generated.

    kw_only : bool, optional (default=False)
        (*New in version 3.10*) If true (the default value is `False`), then all
        fields will be marked as keyword-only.

    slots : bool, optional (default=False)
        (*New in version 3.10*) If true (the default is `False`), `__slots__`
        attribute will be generated and new class will be returned instead of the
        original one. If `__slots__` is already defined in the class, then
        `TypeError` is raised.

    weakref_slot : bool, optional (default=False)
        (*New in version 3.11*) If true (the default is `False`), add a slot named
        `__weakref__`, which is required to make an instance weakref-able. It is
        an error to specify `weakref_slot=True` without also specifying
        `slots=True`.

    References
    ----------
    .. [1] https://docs.python.org/3/library/dataclasses.html
    .. [2] https://docs.python.org/3/library/dataclasses.html#mutable-default-values

    Notes
    -----
    This is a typed dictionary, which provides static type information
    (e.g. type checking and auto completion options) to tooling. Note, however,
    that it provides no runtime checking of its keys and values.

    Examples
    --------
    >>> from hydra_zen.typing import DataclassOptions as Opts
    >>> from hydra_zen import builds, make_config, make_custom_builds_fn

    Creating a frozen config.

    >>> conf = make_config(x=1, zen_dataclass=Opts(frozen=True))()
    >>> conf.x = 2
    FrozenInstanceError: cannot assign to field 'x'

    Creating a pickle-compatible config:

    The dynamically-generated classes created by `builds`, `make_config`, and
    `just` can be made pickle-compatible by specifying the name of the symbol that
    it is assigned to and the module in which it was defined.

    .. code-block:: python

       # contents of mylib/foo.py
       from pickle import dumps, loads

       from hydra_zen import builds

       DictConf = builds(dict,
                         zen_dataclass={'module': 'mylib.foo',
                                        'cls_name': 'DictConf'})

       assert DictConf is loads(dumps(DictConf))

    Using namespace to add a method to a config instance.

    >>> conf = make_config(
    ...     x=100,
    ...     zen_dataclass=Opts(
    ...         namespace=dict(add_x=lambda self, y: self.x + y),
    ...     ),
    ... )()
    >>> conf.add_x(2)
    102

    Dataclass objects created by hydra-zen's config-creation functions will be
    created with `unsafe_hash=True` by default. This is in contrast with the
    default behavior of :py:func:`dataclasses.dataclass`. This is to help ensure
    smooth compatibility through Python 3.11, which changed the mutability
    checking rules for dataclasses [2]_.

    >>> from dataclasses import make_dataclass
    >>> DataClass = make_dataclass(fields=[], cls_name="A")
    >>> DataClass.__hash__
    None

    >>> Conf = make_config(x=2)
    >>> Conf.__hash__
    <function types.__create_fn__.<locals>.__hash__(self)>

    >>> UnHashConf = make_config(x=2, zen_dataclass=Opts(unsafe_hash=False))
    >>> UnHashConf.__hash__
    None
    """

    module: Optional[str]  # also declared on `_Py312Dataclass`
    target: str
    target_repr: bool
def _permitted_keys(typed_dict: Any) -> FrozenSet[str]:
    """Return every key declared by the TypedDict type `typed_dict`.

    The result is the union of its required and optional keys.
    """
    required = typed_dict.__required_keys__
    optional = typed_dict.__optional_keys__
    return required | optional


DEFAULT_DATACLASS_OPTIONS = DataclassOptions(unsafe_hash=True)
PERMITTED_DATACLASS_OPTIONS = _permitted_keys(DataclassOptions)
# Keys of the newest option set (excluding "module") that are not accepted by
# the option set selected for the running Python version.
UNSUPPORTED_DATACLASS_OPTIONS = _permitted_keys(_Py312Dataclass).difference(
    {"module"}, _permitted_keys(StrictDataclassOptions)
)

del _AllPyDataclassOptions, _Py310Dataclass, _Py311Dataclass, _Py312Dataclass