# Source code for ax.utils.common.serialization
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import pydoc
from abc import ABC
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Type
# https://stackoverflow.com/a/39235373
def named_tuple_to_dict(data: Any) -> Any:
    """Recursively replace every NamedTuple inside ``data`` with a dict.

    Dicts, lists and plain tuples are rebuilt (as ``dict``, ``list`` and
    ``tuple`` respectively) with their contents converted, while a NamedTuple
    becomes a dict mapping its field names to converted field values. Any
    other value is returned unchanged. Dict keys are never converted.
    """
    convert = named_tuple_to_dict
    if isinstance(data, dict):
        return {key: convert(item) for key, item in data.items()}
    if isinstance(data, list):
        return list(map(convert, data))
    # Must precede the plain-tuple check: every NamedTuple is also a tuple.
    if _is_named_tuple(data):
        return {field: convert(item) for field, item in data._asdict().items()}
    if isinstance(data, tuple):
        return tuple(map(convert, data))
    return data
# https://stackoverflow.com/a/2166841
def _is_named_tuple(x: Any) -> bool:
    """Return True if ``x`` is an instance of a NamedTuple class.

    The check is structural: the type of ``x`` must have ``tuple`` as its one
    and only base class and expose a ``_fields`` tuple whose entries are all
    strings.

    Args:
        x: Any value.

    Returns:
        Whether ``x`` looks like a NamedTuple instance.
    """
    t = type(x)
    bases = t.__bases__
    # Only direct ``tuple`` subclasses qualify. A subclass of a namedtuple
    # class has that namedtuple (not ``tuple``) as its base and is rejected.
    if len(bases) != 1 or bases[0] is not tuple:
        return False
    fields = getattr(t, "_fields", None)
    if not isinstance(fields, tuple):
        return False  # pragma nocover
    return all(isinstance(name, str) for name in fields)
def callable_to_reference(callable: Callable) -> str:
    """Obtains path to the callable of form <module>.<name>.

    The path is built from ``__module__`` and ``__qualname__``. It is only
    returned if ``pydoc.locate`` resolves it back to this exact object.

    Args:
        callable: A plain Python function or a class.

    Returns:
        The dotted ``"<module>.<qualname>"`` path.

    Raises:
        TypeError: If ``callable`` is not a function or class, or if its path
            does not resolve back to the same object (e.g. lambdas or
            functions defined inside other functions).
    """
    if not isinstance(callable, (FunctionType, type)):
        raise TypeError(f"Expected to encode function or class, got: {callable}.")
    name = f"{callable.__module__}.{callable.__qualname__}"
    not_exposed = (
        f"Callable {callable.__qualname__} is not properly exposed in "
        f"{callable.__module__}"
    )
    try:
        located = pydoc.locate(name)
    except Exception as err:
        # ``pydoc.locate`` may raise, e.g. when importing the module fails.
        raise TypeError(f"{not_exposed} (exception: {err}).") from err
    # Explicit check rather than ``assert``: assert statements are stripped
    # under ``python -O``, which would silently accept unresolvable paths.
    if located is not callable:
        raise TypeError(f"{not_exposed} (path {name!r} resolves to {located!r}).")
    return name
def callable_from_reference(path: str) -> Callable:
    """Look up the object named by a dotted ``<module>.<name>`` path.

    Resolution is delegated entirely to ``pydoc.locate``. A path that cannot
    be resolved yields ``None`` instead of a callable.

    Args:
        path: Dotted path, e.g. as produced by ``callable_to_reference``.

    Returns:
        The object found at ``path`` (or ``None`` if nothing is found).
    """
    located = pydoc.locate(path)
    return located  # pyre-ignore[7]
# TODO: update signature to avoid shadowing python `object` fn.
def serialize_init_args(
    object: Any, exclude_fields: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Given an object, return a dictionary of the arguments that are
    needed by its constructor.

    Every parameter of ``object.__class__.__init__`` is looked up as an
    attribute of the same name on ``object``. Skipped parameters are those
    named ``self``, ``args`` or ``kwargs`` and any in ``exclude_fields``.

    Args:
        object: The instance whose constructor arguments are collected.
        exclude_fields: Extra constructor parameter names to skip. Any
            iterable of names is accepted.

    Returns:
        A dict mapping each remaining constructor parameter name to the value
        of the identically named attribute, in signature order.

    Raises:
        AttributeError: If ``object`` has no attribute for one of the
            constructor parameters that is not skipped.
    """
    # A set gives O(1) membership tests and accepts any iterable of names.
    exclude_args = {"self", "args", "kwargs", *(exclude_fields or [])}
    properties: Dict[str, Any] = {}
    signature = inspect.signature(object.__class__.__init__)
    for arg in signature.parameters:
        if arg in exclude_args:
            continue
        try:
            value = getattr(object, arg)
        except AttributeError as err:
            # Chain explicitly so the failed lookup is reported as the direct
            # cause rather than as an unrelated error "during handling".
            raise AttributeError(
                f"{object.__class__} is missing a value for {arg}, "
                f"which is needed by its constructor."
            ) from err
        properties[arg] = value
    return properties
class SerializationMixin(ABC):
    """Mixin exposing classmethod hooks that (de)serialize constructor args.

    Both hooks are thin delegations to module-level helpers and are used for
    storage.
    """

    @classmethod
    def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
        """Collect the properties of ``obj`` needed to re-create it.

        Delegates to the module-level ``serialize_init_args`` function;
        ``cls`` itself is not consulted. Used for storage.
        """
        init_args = serialize_init_args(obj)
        return init_args

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a stored dictionary back into the properties needed to
        initialize ``cls``. Used for storage.
        """
        # NOTE(review): ``extract_init_args`` is not defined in the visible
        # part of this module; presumably it is defined elsewhere in the
        # file -- confirm.
        return extract_init_args(class_=cls, args=args)