Source code for ax.utils.common.typeutils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch


# Generic type variables used in the signatures of the casting helpers below.
T = TypeVar("T")  # result type of a cast / the type inside an Optional
V = TypeVar("V")  # incoming value type (target value type in checked_cast_dict)
K = TypeVar("K")  # target key type in checked_cast_dict
X = TypeVar("X")  # incoming key type in checked_cast_dict
Y = TypeVar("Y")  # incoming value type in checked_cast_dict


def not_none(val: Optional[T]) -> T:
    """
    Unbox an optional type.

    Args:
        val: the value to cast to a non ``None`` type

    Returns:
        T: ``val`` when ``val`` is not ``None``

    Raises:
        ValueError: if ``val`` is ``None``
    """
    if val is None:
        raise ValueError("Argument to `not_none` was None.")
    return val
def checked_cast(typ: Type[T], val: V) -> T:
    """
    Cast a value to a type (with a runtime safety check).

    Returns the value unchanged and checks its type at runtime. This signals to
    the typechecker that the value has the designated type.

    Like `typing.cast`_, ``checked_cast`` performs no runtime conversion on its
    argument, but, unlike ``typing.cast``, ``checked_cast`` will throw an error
    if the value is not of the expected type. The type passed as an argument
    should be a python class.

    Args:
        typ: the type to cast to
        val: the value that we are casting

    Returns:
        the ``val`` argument, unchanged

    Raises:
        ValueError: if ``val`` is not an instance of ``typ``

    .. _typing.cast: https://docs.python.org/3/library/typing.html#typing.cast
    """
    if not isinstance(val, typ):
        # Report the requested type ``typ``; interpolating the ``type`` builtin
        # would always print "<class 'type'>".
        raise ValueError(f"Value was not of type {typ}:\n{val}")
    return val
def checked_cast_optional(typ: Type[T], val: Optional[V]) -> Optional[T]:
    """Apply ``checked_cast`` to ``val`` unless it is ``None``.

    Args:
        typ: the type ``val`` must have when it is not ``None``
        val: the value to check, or ``None``

    Returns:
        ``None`` when ``val`` is ``None``; otherwise ``val`` unchanged, after
        ``checked_cast`` has verified it is an instance of ``typ`` (any error
        raised by ``checked_cast`` propagates).
    """
    return None if val is None else checked_cast(typ, val)
def checked_cast_list(typ: Type[T], l: List[V]) -> List[T]:
    """Return a new list with ``checked_cast(typ, ...)`` applied to every item.

    Items are checked in order. The first item that is not an instance of
    ``typ`` makes ``checked_cast`` raise, and no list is returned. The input
    list is not modified.
    """
    return [checked_cast(typ, item) for item in l]
def checked_cast_dict(
    key_typ: Type[K], value_typ: Type[V], d: Dict[X, Y]
) -> Dict[K, V]:
    """Return a new dict with every key and value passed through ``checked_cast``.

    Args:
        key_typ: the type every key must have
        value_typ: the type every value must have
        d: the dictionary to check; it is not modified

    Returns:
        a new dict holding the same entries, inserted in ``d``'s iteration order
    """
    checked: Dict[K, V] = {}
    for raw_key, raw_value in d.items():
        # Each value is checked before its key, so when both are invalid the
        # value's error is the one reported.
        checked_value = checked_cast(value_typ, raw_value)
        checked[checked_cast(key_typ, raw_key)] = checked_value
    return checked
# pyre-fixme[34]: `T` isn't present in the function's parameters.
def checked_cast_to_tuple(typ: Tuple[Type[V], ...], val: V) -> T:
    """
    Cast a value to a union of multiple types (with a runtime safety check).

    This function is similar to `checked_cast`, but allows for the type to be
    defined as a tuple of types, in which case the value is cast as a union of
    the types in the tuple.

    Args:
        typ: the tuple of types to cast to
        val: the value that we are casting

    Returns:
        the ``val`` argument, unchanged

    Raises:
        ValueError: if ``val`` is not an instance of any type in ``typ``
    """
    if not isinstance(val, typ):
        # Report the requested tuple ``typ``; interpolating the ``type``
        # builtin would always print "<class 'type'>".
        raise ValueError(f"Value was not of type {typ!r}:\n{val!r}")
    return val
def numpy_type_to_python_type(value: Any) -> Any:
    """Convert a NumPy integer or floating scalar to the matching Python type.

    Some of our transforms return NumPy values, so this turns an ``np.integer``
    into an ``int`` and an ``np.floating`` into a ``float``. Any other value
    is returned as-is.
    """
    if isinstance(value, np.integer):
        return int(value)  # pragma: nocover (covered by generator tests)
    if isinstance(value, np.floating):
        return float(value)  # pragma: nocover (covered by generator tests)
    return value
def torch_type_to_str(value: Any) -> str:
    """Serialize a ``torch.dtype`` or ``torch.device`` to a string.

    A dtype becomes its ``str()`` form (e.g. ``"torch.float64"``). A device
    becomes its ``type`` attribute, checked to be a ``str``.

    Args:
        value: the dtype or device to convert.

    Returns:
        the string representation described above.

    Raises:
        ValueError: if ``value`` is neither a ``torch.dtype`` nor a
            ``torch.device``.
    """
    if isinstance(value, torch.device):
        # NOTE(review): only ``device.type`` is kept, so a device index (e.g.
        # the ``1`` in ``cuda:1``) is not part of the result.
        # pyre-fixme[16]: device has no attribute `type`.
        return checked_cast(str, value.type)
    if not isinstance(value, torch.dtype):
        raise ValueError(f"Object {value} was of unexpected torch type.")
    return str(value)
def torch_type_from_str(
    identifier: str, type_name: str
) -> Union[torch.dtype, torch.device]:
    """Rebuild a ``torch.dtype`` or ``torch.device`` from its string form.

    Args:
        identifier: when ``type_name == "device"``, a string passed to
            ``torch.device`` (e.g. ``"cpu"``). When ``type_name == "dtype"``,
            a dtype name, with or without the ``"torch."`` prefix that
            ``str(torch.float64)`` produces (e.g. ``"torch.float64"`` or
            ``"float64"``).
        type_name: either ``"device"`` or ``"dtype"``.

    Returns:
        the corresponding ``torch.device`` or ``torch.dtype``.

    Raises:
        ValueError: if ``type_name`` is not recognized, or if ``identifier``
            names a ``torch`` attribute that is not a dtype.
        AttributeError: if ``identifier`` does not name any ``torch`` attribute.
    """
    if type_name == "device":
        return torch.device(identifier)
    if type_name == "dtype":
        prefix = "torch."
        # Strip the prefix only when it is present. Blindly slicing off six
        # characters turns a bare name such as "float64" into "4".
        if identifier.startswith(prefix):
            name = identifier[len(prefix):]
        else:
            name = identifier
        dtype = getattr(torch, name)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Identifier {identifier} does not name a torch dtype.")
        return dtype
    raise ValueError(f"Unexpected type: {type_name} for identifier: {identifier}.")