# Source code for hydra_zen.third_party.pydantic

# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
# pyright: reportUnnecessaryTypeIgnoreComment=false
import functools
import inspect
from typing import Any, Callable, TypeVar, Union, cast

import pydantic as _pyd

_T = TypeVar("_T", bound=Callable[..., Any])

# `pydantic_parser` is the recommended public entry point;
# `validates_with_pydantic` is retained (deprecated) for backwards compatibility.
__all__ = ["pydantic_parser", "validates_with_pydantic"]


# Compare the *numeric* major version. A plain string comparison such as
# `_pyd.__version__ >= "2.0"` is lexicographic, so e.g. "10.0" >= "2.0" is False.
_PYDANTIC_MAJOR_VERSION = int(_pyd.__version__.split(".", 1)[0])

if _PYDANTIC_MAJOR_VERSION >= 2:  # pragma: no cover
    # pydantic v2: `validate_call` takes `validate_return` as its own keyword.
    _default_parser = _pyd.validate_call(
        config={"arbitrary_types_allowed": True}, validate_return=False  # type: ignore
    )
else:  # pragma: no cover
    # pydantic v1: fall back to `validate_arguments`.
    # NOTE(review): "validate_return" is passed inside the config dict here;
    # presumably it has no effect under v1 — confirm.
    _default_parser = _pyd.validate_arguments(
        config={"arbitrary_types_allowed": True, "validate_return": False}  # type: ignore
    )


def _constructor_as_fn(cls: Any) -> Any:
    """Makes a shim around a class constructor so that it is compatible with pydantic validation.

    Notes
    -----
    `pydantic.validate_call` mishandles class constructors; it expects that
    `cls`/`self` should be passed explicitly to the constructor. This shim
    corrects that.
    """

    @functools.wraps(cls)
    def wrapper_function(*args, **kwargs):
        return cls(*args, **kwargs)

    annotations = getattr(cls, "__annotations__", {})

    # In a case like:
    # class A:
    #   x: int
    #   def __init__(self, y: int): ...
    #
    #  y will not be in __annotations__ but it should be in the signature,
    #  so we add it to the annotations.

    sig = inspect.signature(cls)
    for p, v in sig.parameters.items():
        if p not in annotations:
            annotations[p] = v.annotation
    wrapper_function.__annotations__ = annotations

    return wrapper_function


def _get_signature(x: Any) -> Union[None, inspect.Signature]:
    try:
        return inspect.signature(x)
    except Exception:
        return None


def pydantic_parser(target: _T, *, parser: Callable[[_T], _T] = _default_parser) -> _T:
    """A target-wrapper that adds pydantic parsing to the target.

    This can be passed to `instantiate` as a `_target_wrapper_` to add
    pydantic parsing to the (recursive) instantiation of the target.

    Parameters
    ----------
    target : Callable
        The function or class to wrap. Builtins, and objects for which
        `inspect.signature` fails, are returned unchanged.

    parser : Callable[[T], T], optional
        A configured instance of pydantic's validation decorator
        (`pydantic.validate_call` on pydantic v2,
        `pydantic.validate_arguments` on pydantic v1).

        The default validator that we provide specifies:
           - arbitrary_types_allowed: True
           - validate_return: False

    Returns
    -------
    Callable
        ``parser(target)``. For a class, the parser is applied to a shim
        around the class's constructor instead of to the class itself.

    Examples
    --------
    .. code-block:: python

       from hydra_zen import builds, instantiate
       from hydra_zen.third_party.pydantic import pydantic_parser

       from pydantic import PositiveInt

       def f(x: PositiveInt): return x

       good_conf = builds(f, x=10)
       bad_conf = builds(f, x=-3)

    >>> instantiate(good_conf, _target_wrapper_=pydantic_parser)
    10
    >>> instantiate(bad_conf, _target_wrapper_=pydantic_parser)
    ValidationError: 1 validation error for f (...)

    This also enables type conversion / parsing. E.g. Hydra can
    only produce lists from the CLI, but this parsing layer can
    convert them based on the annotated type. (Note: this only works
    for pydantic v2 and higher.)

    >>> def g(x: tuple): return x
    >>> conf = builds(g, x=[1, 2, 3])
    >>> instantiate(conf, _target_wrapper_=pydantic_parser)
    (1, 2, 3)
    """
    # Builtins and objects without an introspectable signature are left unwrapped.
    if inspect.isbuiltin(target):
        return target
    if _get_signature(target) is None:
        return target

    if inspect.isclass(target):
        # Validate the constructor's arguments through a plain-function shim;
        # see `_constructor_as_fn` for why the class itself is not passed.
        return cast(_T, parser(_constructor_as_fn(target)))
    return parser(target)
def validates_with_pydantic(
    obj: _T, *, validator: Callable[[_T], _T] = _default_parser
) -> _T:
    """Apply pydantic parsing to ``obj`` by delegating to `pydantic_parser`.

    .. deprecated:: 0.13.0
        Use `hydra_zen.third_party.pydantic.pydantic_parser` instead.

    Parameters
    ----------
    obj : Callable
        The function or class to wrap.

    validator : Callable[[T], T], optional
        A configured pydantic validation decorator; forwarded to
        `pydantic_parser` as its ``parser`` argument.

    Returns
    -------
    Callable
        Exactly ``pydantic_parser(obj, parser=validator)``.
    """
    return pydantic_parser(obj, parser=validator)