# Source code for ax.utils.common.decorator
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from abc import ABC, abstractmethod
from collections.abc import Callable
import inspect
from typing import Any, TypeVar
T = TypeVar("T")


class ClassDecorator(ABC):
    """
    Template for making a decorator work as a class level decorator. That decorator
    should extend `ClassDecorator`. It must implement `decorate_callable` and may
    define `__init__` to accept its own arguments. See
    `disable_logger.decorate_callable` for an example.

    `decorate_callable` should call `self._call_func()` instead of directly calling
    `func` to handle static functions.

    `decorate_class` re-wraps `staticmethod` and `classmethod` members in the same
    kind of descriptor, so their wrappers receive the same arguments the original
    members would (in particular, subclasses still get their own `cls`).
    `_call_func` additionally retries a call without its first positional argument
    when the arguments do not fit the callable's signature, which covers static
    functions whose wrapper ends up stored as a plain method.

    Note: `_call_func` is still imperfect and unit tests should be used to ensure
    everything is working properly. There is a lot of complexity in detecting
    classmethods and staticmethods and removing the self argument in the right
    situations. For best results always use keyword args in the decorated class.

    `DECORATE_PRIVATE` can be set to determine whether private methods should be
    decorated. In the case of a logging decorator, you may only want to decorate things
    the user calls. But in the case of a disable logging decorator, you may want to
    decorate everything to ensure no logs escape.
    """

    # When False, `decorate_class` skips attributes whose names start with "_".
    DECORATE_PRIVATE = True

    def decorate_class(self, klass: T) -> T:
        """Apply `decorate_callable` to the callable attributes of `klass`, in place.

        Attributes are enumerated with `dir`, so inherited members (including those
        provided by `object`) are wrapped too, and the wrappers are set directly on
        `klass`. Non-callables, nested classes and a fixed set of special names are
        left untouched.

        Args:
            klass: The class to decorate.

        Returns:
            `klass` itself, after its attributes have been replaced.
        """
        for attr in dir(klass):
            if not self.DECORATE_PRIVATE and attr.startswith("_"):
                continue
            attr_value = getattr(klass, attr)
            if (
                not callable(attr_value)
                or isinstance(attr_value, type)
                or attr
                in (
                    "__subclasshook__",
                    "__class__",
                    "__repr__",
                    "__str__",
                    "__getattribute__",
                    "__new__",
                    "__call__",
                    "__eq__",
                    # Presumably skipped so that decorating a ClassDecorator
                    # subclass does not route `_call_func` through its own
                    # wrapper -- confirm.
                    "_call_func",
                )
            ):
                continue
            # `getattr` unwraps staticmethod/classmethod descriptors (and binds
            # classmethods to `klass`), so look at the raw attribute to restore
            # the same descriptor around the wrapper.
            raw_value = inspect.getattr_static(klass, attr, None)
            decorated: object
            if isinstance(raw_value, staticmethod):
                decorated = staticmethod(self.decorate_callable(attr_value))
            elif isinstance(raw_value, classmethod) and callable(raw_value.__func__):
                # Wrap the underlying function, not the method bound to `klass`,
                # so subclasses keep receiving their own `cls`.
                decorated = classmethod(self.decorate_callable(raw_value.__func__))
            else:
                decorated = self.decorate_callable(attr_value)
            setattr(klass, attr, decorated)
        return klass

    @abstractmethod
    def decorate_callable(self, func: Callable[..., T]) -> Callable[..., T]:
        """Return a wrapped version of `func`.

        Implementations should invoke `func` via `self._call_func(func, ...)` rather
        than calling it directly, so that a spurious leading positional argument is
        tolerated (see the class docstring).
        """

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        """Decorate `func`.

        Classes are handled by `decorate_class`; any other callable is handled by
        `decorate_callable`.
        """
        if isinstance(func, type):
            return self.decorate_class(func)
        return self.decorate_callable(func)

    def _call_func(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
        """Call `func`, tolerating a spurious leading positional argument.

        A static function whose wrapper is stored on a class as a plain function
        receives the instance as an extra first positional argument when called
        through an instance. If the arguments do not fit `func`'s signature, the
        call is retried once with the first positional argument dropped. A
        `TypeError` raised from inside `func` itself is never retried, so `func`
        runs at most once per successful binding.

        Args:
            func: The original callable being wrapped.
            *args: Positional arguments received by the wrapper.
            **kwargs: Keyword arguments received by the wrapper.

        Returns:
            The return value of `func`.

        Raises:
            TypeError: The error from the first call, re-raised when a retry does
                not apply or also fails with a `TypeError`.
        """
        try:
            return func(*args, **kwargs)
        except TypeError as e:
            # With no positional arguments there is nothing to drop; a "retry"
            # would repeat the identical call and run `func` a second time.
            if not args:
                raise
            try:
                inspect.signature(func).bind(*args, **kwargs)
                arguments_fit = True
            except (TypeError, ValueError):
                # Either the arguments do not match `func`'s signature (so the
                # call was rejected before `func` ran), or the signature cannot
                # be introspected and we cannot tell; retry in both cases.
                arguments_fit = False
            if arguments_fit:
                # The TypeError came from inside `func`'s body; retrying would
                # execute `func` again.
                raise
            # static functions
            try:
                return func(*args[1:], **kwargs)
            except TypeError:
                # it wasn't that it was a static function
                raise e from None