Source code for ax.utils.common.serialization
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import inspect
import pydoc
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Type
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
TDecoderRegistry = Dict[str, Type]
# pyre-fixme[33]: `TClassDecoderRegistry` cannot alias to a type containing `Any`.
TClassDecoderRegistry = Dict[str, Callable[[Dict[str, Any]], Any]]
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def named_tuple_to_dict(data: Any) -> Any:
    """Recursively convert NamedTuples to dictionaries.

    Walks ``data`` and rebuilds it so that every named tuple is replaced by a
    plain ``dict`` keyed by field name. Dicts and lists are rebuilt as new
    ``dict``/``list`` objects with their values converted recursively, plain
    tuples are rebuilt as tuples, and any other value is returned unchanged.

    Args:
        data: Arbitrary, possibly nested, object to convert.

    Returns:
        A structure mirroring ``data`` in which named tuples have become dicts.
    """
    if isinstance(data, dict):
        return {key: named_tuple_to_dict(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [named_tuple_to_dict(value) for value in data]
    elif _is_named_tuple(data):
        # Checked before the generic ``tuple`` branch: a named tuple is also a
        # tuple and would otherwise lose its field names.
        return {
            key: named_tuple_to_dict(value) for key, value in data._asdict().items()
        }
    elif isinstance(data, tuple):
        return tuple(named_tuple_to_dict(value) for value in data)
    return data
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def _is_named_tuple(x: Any) -> bool:
"""Return True if x is an instance of NamedTuple."""
t = type(x)
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, "_fields", None)
if not isinstance(f, tuple):
return False # pragma nocover
return all(isinstance(n, str) for n in f)
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def callable_to_reference(callable: Callable) -> str:
    """Obtains path to the callable of form <module>.<name>.

    The path is built from the callable's ``__module__`` and ``__qualname__``
    and is only returned if resolving it with ``pydoc.locate`` yields the very
    same object, so the reference can later be turned back into the callable.

    Args:
        callable: A plain Python function or a class.

    Returns:
        The dotted path ``<module>.<qualname>`` of ``callable``.

    Raises:
        TypeError: If ``callable`` is neither a function nor a class, or if the
            computed path does not resolve back to ``callable``.
    """
    if not isinstance(callable, (FunctionType, type)):
        raise TypeError(f"Expected to encode function or class, got: {callable}.")
    name = f"{callable.__module__}.{callable.__qualname__}"
    try:
        located = pydoc.locate(name)
    except Exception as err:
        raise TypeError(
            f"Callable {callable.__qualname__} is not properly exposed in "
            f"{callable.__module__} (exception: {err})."
        ) from err
    # Explicit check rather than ``assert`` so that the validation still runs
    # when Python is started with ``-O`` (which strips assert statements).
    if located is not callable:
        raise TypeError(
            f"Callable {callable.__qualname__} is not properly exposed in "
            f"{callable.__module__} (exception: `{name}` resolves to {located!r})."
        )
    return name
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def callable_from_reference(path: str) -> Callable:
    """Retrieves a callable by its path.

    Args:
        path: Dotted path of the form ``<module>.<name>``.

    Returns:
        Whatever ``pydoc.locate`` resolves ``path`` to. No validation is done
        on the result; a path that cannot be resolved yields ``None``.
    """
    return pydoc.locate(path)  # pyre-ignore[7]
def serialize_init_args(
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    obj: Any,
    exclude_fields: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Given an object, return a dictionary of the arguments that are
    needed by its constructor.

    Each parameter name in the signature of ``obj.__class__.__init__`` is
    looked up as an attribute of the same name on ``obj``.

    Args:
        obj: Object whose constructor arguments should be collected.
        exclude_fields: Extra parameter names to skip, in addition to
            ``self``, ``args`` and ``kwargs`` (which are always skipped,
            matched by name).

    Returns:
        Mapping from constructor parameter name to the current value of the
        identically named attribute on ``obj``.

    Raises:
        AttributeError: If ``obj`` has no attribute for one of the
            non-excluded constructor parameters.
    """
    properties = {}
    exclude_args = ["self", "args", "kwargs"] + (exclude_fields or [])
    signature = inspect.signature(obj.__class__.__init__)
    for arg in signature.parameters:
        if arg in exclude_args:
            continue
        try:
            value = getattr(obj, arg)
        except AttributeError as err:
            # Re-raise with a message naming the class and parameter, keeping
            # the original lookup failure as the explicit cause.
            raise AttributeError(
                f"{obj.__class__} is missing a value for {arg}, "
                f"which is needed by its constructor."
            ) from err
        properties[arg] = value
    return properties
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
class SerializationMixin:
    """Mixin providing class-level hooks for (de)serializing constructor args.

    Both hooks are classmethods that delegate to module-level helpers, so a
    subclass can override them to customize how its constructor arguments are
    stored and restored.
    """

    @classmethod
    def serialize_init_args(cls, obj: SerializationMixin) -> Dict[str, Any]:
        """Serialize the properties needed to initialize the object.
        Used for storage.

        Args:
            obj: Instance whose constructor arguments should be collected.

        Returns:
            The result of the module-level ``serialize_init_args`` for ``obj``.
        """
        return serialize_init_args(obj=obj)

    @classmethod
    def deserialize_init_args(
        cls,
        args: Dict[str, Any],
        decoder_registry: Optional[TDecoderRegistry] = None,
        class_decoder_registry: Optional[TClassDecoderRegistry] = None,
    ) -> Dict[str, Any]:
        """Given a dictionary, deserialize the properties needed to initialize the
        object. Used for storage.

        Args:
            args: Dictionary holding the stored constructor arguments.
            decoder_registry: Not used by this base implementation; accepted
                so that overrides can make use of it.
            class_decoder_registry: Not used by this base implementation;
                accepted so that overrides can make use of it.

        Returns:
            The result of ``extract_init_args`` for ``args`` and this class.
        """
        return extract_init_args(args=args, class_=cls)