Source code for ax.utils.common.typeutils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from inspect import signature
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

import numpy as np
from typeguard import check_type

# Generic type variables used in the annotations of the helpers below.
T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K")
X = TypeVar("X")
Y = TypeVar("Y")


def not_none(val: Optional[T], message: Optional[str] = None) -> T:
    """Unwrap an ``Optional`` value, failing loudly if it is ``None``.

    Args:
        val: The possibly-``None`` value to unwrap.
        message: Error text to use instead of the default one. A falsy
            message (``None`` or ``""``) falls back to the default.

    Returns:
        ``val`` itself, unchanged, when it is not ``None``.

    Raises:
        ValueError: If ``val`` is ``None``.
    """
    if val is not None:
        return val
    raise ValueError(message or "Argument to `not_none` was None.")
def checked_cast(typ: Type[T], val: V, exception: Optional[Exception] = None) -> T:
    """Cast a value to a type, verifying the type at runtime.

    The value is returned unchanged; no conversion is performed. Like
    `typing.cast`_ this tells the typechecker that the value has the given
    type, but unlike ``typing.cast``, ``checked_cast`` raises if the value is
    not actually an instance of that type. ``typ`` should be a Python class
    (it is handed to ``isinstance``).

    Args:
        typ: The type to cast to.
        val: The value being cast.
        exception: Exception instance to raise instead of the default
            ``ValueError`` when the check fails.

    Returns:
        The ``val`` argument, unchanged.

    Raises:
        ValueError: If ``val`` is not an instance of ``typ`` and no
            ``exception`` override was supplied.

    .. _typing.cast: https://docs.python.org/3/library/typing.html#typing.cast
    """
    if isinstance(val, typ):
        return val
    if exception is not None:
        raise exception
    raise ValueError(f"Value was not of type {typ}:\n{val}")
def checked_cast_optional(typ: Type[T], val: Optional[V]) -> Optional[T]:
    """Apply ``checked_cast`` to ``val`` unless it is ``None``.

    Args:
        typ: The type to cast to.
        val: The value being cast, or ``None``.

    Returns:
        ``None`` when ``val`` is ``None``; otherwise the result of
        ``checked_cast(typ, val)``.
    """
    return None if val is None else checked_cast(typ, val)
def checked_cast_list(typ: Type[T], old_l: List[V]) -> List[T]:
    """Apply ``checked_cast`` to every item of a list.

    Args:
        typ: The type every item is cast to.
        old_l: The list whose items are checked; it is not modified.

    Returns:
        A new list holding the items of ``old_l`` in the same order.
    """
    return [checked_cast(typ, item) for item in old_l]
def checked_cast_dict(
    key_typ: Type[K], value_typ: Type[V], d: Dict[X, Y]
) -> Dict[K, V]:
    """Apply ``checked_cast`` to every key and value of a dictionary.

    Args:
        key_typ: The type every key is cast to.
        value_typ: The type every value is cast to.
        d: The dictionary to check; it is not modified.

    Returns:
        A new dictionary with the same items as ``d``.
    """
    checked = {}
    for raw_key, raw_val in d.items():
        # The value is validated before the key for each item.
        cast_val = checked_cast(value_typ, raw_val)
        checked[checked_cast(key_typ, raw_key)] = cast_val
    return checked
# pyre-fixme[34]: `T` isn't present in the function's parameters.
def checked_cast_to_tuple(typ: Tuple[Type[V], ...], val: V) -> T:
    """Cast a value to a union of multiple types (with a runtime safety check).

    This function is similar to ``checked_cast``, but the type is given as a
    tuple of types, and the value is accepted if it is an instance of any of
    them. The value itself is never converted.

    Args:
        typ: The tuple of types to cast to.
        val: The value being cast.

    Returns:
        The ``val`` argument, unchanged.

    Raises:
        ValueError: If ``val`` is not an instance of any type in ``typ``.
    """
    if not isinstance(val, typ):
        # Report the expected types (`typ`), not the builtin `type`.
        raise ValueError(f"Value was not of type {typ!r}:\n{val!r}")
    # pyre-fixme[7]: Expected `T` but got `V`.
    return val
def version_safe_check_type(argname: str, value: T, expected_type: Type[T]) -> None:
    """Run typeguard's ``check_type`` only if it has the expected signature.

    The installed ``check_type`` is inspected; it is called as
    ``check_type(argname, value, expected_type)`` only when its signature
    has parameters named ``argname``, ``value`` and ``expected_type``. If any
    of them is missing, the check is skipped and nothing is raised or
    logged. This is done to support newer versions of typeguard with minimal
    loss of functionality for users that have dependency conflicts.

    Args:
        argname: Name of the argument being checked, passed to ``check_type``.
        value: The value to check.
        expected_type: The type ``value`` is expected to have.
    """
    required_params = ("argname", "value", "expected_type")
    available_params = signature(check_type).parameters
    if any(name not in available_params for name in required_params):
        return
    check_type(argname, value, expected_type)
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def numpy_type_to_python_type(value: Any) -> Any:
    """Coerce a Numpy int or float scalar to the matching Python type.

    This is necessary because some of our transforms return Numpy values.

    Args:
        value: Any object.

    Returns:
        ``int(value)`` if ``value`` is an ``np.integer``, ``float(value)`` if
        it is an ``np.floating``, and ``value`` itself otherwise.
    """
    if isinstance(value, np.integer):
        return int(value)  # pragma: nocover (covered by generator tests)
    if isinstance(value, np.floating):
        return float(value)  # pragma: nocover (covered by generator tests)
    return value
# pyre-fixme[2]: Parameter annotation cannot be `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to # avoid runtime subscripting errors. def _argparse_type_encoder(arg: Any) -> Type: """ Transforms arguments passed to `optimizer_argparse.__call__` at runtime to construct the key used for method lookup as `tuple(map(arg_transform, args))`. This custom arg_transform allow type variables to be passed at runtime. """ # Allow type variables to be passed as arguments at runtime return arg if isinstance(arg, type) else type(arg)