Source code for ax.utils.common.typeutils_torch
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Union
import torch
from ax.utils.common.typeutils import checked_cast
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def torch_type_to_str(value: Any) -> str:
    """Serialize a torch dtype or device, as commonly used in Ax, to a string.

    Args:
        value: A ``torch.dtype`` or ``torch.device`` instance.

    Returns:
        ``str(value)`` for a dtype (e.g. ``"torch.float32"``); for a device,
        the value of its ``type`` attribute (e.g. ``"cpu"``).

    Raises:
        ValueError: If ``value`` is neither a ``torch.dtype`` nor a
            ``torch.device``.
    """
    if isinstance(value, torch.device):
        # Only the device's `type` attribute is serialized, not the full
        # device string.
        return checked_cast(str, value.type)
    if not isinstance(value, torch.dtype):
        raise ValueError(f"Object {value} was of unexpected torch type.")
    return str(value)
def torch_type_from_str(
    identifier: str, type_name: str
) -> Union[torch.dtype, torch.device]:
    """Reconstruct a torch dtype or device from its string representation.

    Args:
        identifier: The serialized value. For ``type_name == "device"`` this
            is passed directly to ``torch.device`` (e.g. ``"cpu"``). For
            ``type_name == "dtype"`` this is the dtype name, with or without
            a leading ``"torch."`` (e.g. ``"torch.float32"`` or ``"float32"``).
        type_name: Which kind of object to build: ``"device"`` or ``"dtype"``.

    Returns:
        The corresponding ``torch.device`` or ``torch.dtype``.

    Raises:
        ValueError: If ``type_name`` is neither ``"device"`` nor ``"dtype"``,
            or if ``identifier`` names a ``torch`` attribute that is not a
            dtype.
        AttributeError: If ``type_name`` is ``"dtype"`` and ``torch`` has no
            attribute with the given name.
    """
    if type_name == "device":
        return torch.device(identifier)
    if type_name == "dtype":
        # Strip the "torch." prefix only when it is actually present, rather
        # than blindly dropping the first six characters of the identifier.
        prefix = "torch."
        if identifier.startswith(prefix):
            attr_name = identifier[len(prefix):]
        else:
            attr_name = identifier
        candidate = getattr(torch, attr_name)
        # Guard against returning arbitrary torch attributes (modules,
        # functions, ...) that happen to match the identifier.
        if not isinstance(candidate, torch.dtype):
            raise ValueError(
                f"Identifier {identifier} does not name a torch dtype."
            )
        return candidate
    raise ValueError(f"Unexpected type: {type_name} for identifier: {identifier}.")