# Source code for rai_toolbox.losses._utils

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from functools import wraps
from typing import Any, Callable, TypeVar, cast

from typing_extensions import Protocol


class Negateable(Protocol):
    """Structural type for any object that supports the unary ``-`` operator."""

    def __neg__(self) -> Any:  # pragma: no cover
        """Return the negation of this object (the result of ``-self``)."""


# Type variable bound to callables whose return value supports unary negation
# (i.e. implements ``__neg__``). `negate` uses it to annotate that it returns
# the same callable type that it receives.
T = TypeVar("T", bound=Callable[..., Negateable])


def negate(func: T) -> T:
    """Wrap ``func`` so that its return value is negated.

    The returned wrapper forwards all positional and keyword arguments to
    ``func`` unchanged and applies the unary ``-`` operator to the result.

    Parameters
    ----------
    func : Callable[..., Negateable]
        A callable whose return value supports unary negation (``__neg__``).

    Returns
    -------
    Callable[..., Any]
        A callable carrying ``func``'s metadata (via `functools.wraps`) that
        returns ``-func(*args, **kwargs)``.

    Examples
    --------
    >>> from rai_toolbox import negate
    >>> f = negate(lambda x: 2 * x)
    >>> f(1)
    -2
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return -func(*args, **kwargs)

    # The wrapper accepts exactly what `func` accepts; the cast tells static
    # type checkers to treat it as having the same type as `func`.
    return cast(T, wrapper)