Source code for ax.utils.common.func_enum
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from enum import Enum, unique
from importlib import import_module
from typing import Any, Callable
from ax.exceptions.core import UnsupportedError
[docs]
@unique
class FuncEnum(Enum):
    """Base class for enums whose string values are names of functions.

    Each member's value is looked up as an attribute on the module in which
    the member's enum class is defined (``self.__module__``). Calling a member
    calls the attribute found under that name.
    """

    # pyre-ignore[3]: Input constructors will be used to make different inputs,
    # so we need to allow `Any` return type here.
    def __call__(self, **kwargs: Any) -> Any:
        """Call the function registered under this member's value.

        Makes members callable, e.g. ``MyFunctions.F(**kwargs)`` calls the
        function whose name is the value of ``MyFunctions.F``.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to the function.

        Returns:
            Whatever the resolved function returns.

        Raises:
            UnsupportedError: If the enum's module has no attribute named
                after this member's value.
        """
        return self._get_function_for_value()(**kwargs)

    def _get_function_for_value(self) -> Callable[..., Any]:
        """Retrieve the attribute of this enum's module named by the value.

        Returns:
            The object found under ``self.value`` on the module named by
            ``self.__module__``.

        Raises:
            UnsupportedError: If that module has no such attribute.
        """
        module = import_module(self.__module__)
        # Only the attribute lookup is guarded, so the handler below is
        # reached solely when the name is missing from the module.
        try:
            return getattr(module, self.value)
        except AttributeError as err:
            # Chain explicitly so the original lookup failure stays attached
            # to the error surfaced to the caller.
            raise UnsupportedError(
                f"{self.value} is not defined as a method in "
                f"`{self.__module__}`. Please add the method "
                "to the file."
            ) from err