Source code for ax.utils.common.executils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import asyncio
import functools
import time
from contextlib import contextmanager
from logging import Logger
from typing import Any, Generator, List, Optional, Tuple, Type


MAX_WAIT_SECONDS: int = 600


# pyre-fixme[3]: Return annotation cannot be `Any`.
[docs]def retry_on_exception( exception_types: Optional[Tuple[Type[Exception], ...]] = None, no_retry_on_exception_types: Optional[Tuple[Type[Exception], ...]] = None, check_message_contains: Optional[List[str]] = None, retries: int = 3, suppress_all_errors: bool = False, logger: Optional[Logger] = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. default_return_on_suppression: Optional[Any] = None, wrap_error_message_in: Optional[str] = None, initial_wait_seconds: Optional[int] = None, ) -> Optional[Any]: """ A decorator for instance methods or standalone functions that makes them retry on failure and allows to specify on which types of exceptions the function should and should not retry. NOTE: If the argument `suppress_all_errors` is supplied and set to True, the error will be suppressed and default value returned. Args: exception_types: A tuple of exception(s) types to catch in the decorated function. If none is provided, baseclass Exception will be used. no_retry_on_exception_types: Exception types to consider non-retryable even if their supertype appears in `exception_types` or the only exceptions to not retry on if no `exception_types` are specified. check_message_contains: A list of strings, against which to match error messages. If the error message contains any one of these strings, the exception will cause a retry. NOTE: This argument works in addition to `exception_types`; if those are specified, only the specified types of exceptions will be caught and retried on if they contain the strings provided as `check_message_contains`. retries: Number of retries to perform. suppress_all_errors: If true, after all the retries are exhausted, the error will still be suppressed and `default_return_on_suppresion` will be returned from the function. NOTE: If using this argument, the decorated function may not actually get fully executed, if it consistently raises an exception. logger: A handle for the logger to be used. default_return_on_suppression: If the error is suppressed after all the retries, then this default value will be returned from the function. Defaults to None. wrap_error_message_in: If raising the error message after all the retries, a string wrapper for the error message (useful for making error messages more user-friendly). NOTE: Format of resulting error will be: "<wrap_error_message_in>: <original_error_type>: <original_error_msg>", with the stack trace of the original message. initial_wait_seconds: Initial length of time to wait between failures, doubled after each failure up to a maximum of 10 minutes. If unspecified then there is no wait between retries. """ # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def func_wrapper(func): # Depending on whether `func` is async or not, we use a slightly different # wrapper; if wrapping an async function, decorator will await it. # `async_actual_wrapper` and `actual_wrapper` are almost exactly the same, # except that the former is async and awaits the wrapped function. if asyncio.iscoroutinefunction(func): @functools.wraps(func) # pyre-fixme[53]: Captured variable `func` is not annotated. # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. async def async_actual_wrapper(*args, **kwargs): ( retry_exceptions, no_retry_exceptions, suppress_errors, ) = _validate_and_fill_defaults( retry_on_exception_types=exception_types, no_retry_on_exception_types=no_retry_on_exception_types, suppress_errors=suppress_all_errors, **kwargs, ) for i in range(retries): with handle_exceptions_in_retries( no_retry_exceptions=no_retry_exceptions, retry_exceptions=retry_exceptions, suppress_errors=suppress_errors, # pyre-fixme[6]: For 4th param expected `Optional[str]` but # got `Optional[List[str]]`. check_message_contains=check_message_contains, last_retry=i >= retries - 1, logger=logger, wrap_error_message_in=wrap_error_message_in, ): if i > 0 and initial_wait_seconds is not None: wait_interval = min( MAX_WAIT_SECONDS, initial_wait_seconds * 2 ** (i - 1) ) asyncio.sleep(wait_interval) return await func(*args, **kwargs) # If we are here, it means the retries were finished but # The error was suppressed. Hence return the default value provided. return default_return_on_suppression return async_actual_wrapper @functools.wraps(func) # pyre-fixme[53]: Captured variable `func` is not annotated. # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def actual_wrapper(*args, **kwargs): ( retry_exceptions, no_retry_exceptions, suppress_errors, ) = _validate_and_fill_defaults( retry_on_exception_types=exception_types, no_retry_on_exception_types=no_retry_on_exception_types, suppress_errors=suppress_all_errors, **kwargs, ) for i in range(retries): with handle_exceptions_in_retries( no_retry_exceptions=no_retry_exceptions, retry_exceptions=retry_exceptions, suppress_errors=suppress_errors, # pyre-fixme[6]: For 4th param expected `Optional[str]` but got # `Optional[List[str]]`. check_message_contains=check_message_contains, last_retry=i >= retries - 1, logger=logger, wrap_error_message_in=wrap_error_message_in, ): if i > 0 and initial_wait_seconds is not None: wait_interval = min( MAX_WAIT_SECONDS, initial_wait_seconds * 2 ** (i - 1) ) time.sleep(wait_interval) return func(*args, **kwargs) # If we are here, it means the retries were finished but # The error was suppressed. Hence return the default value provided. return default_return_on_suppression return actual_wrapper return func_wrapper
[docs]@contextmanager def handle_exceptions_in_retries( no_retry_exceptions: Tuple[Type[Exception], ...], retry_exceptions: Tuple[Type[Exception], ...], suppress_errors: bool, check_message_contains: Optional[str], last_retry: bool, logger: Optional[Logger], wrap_error_message_in: Optional[str], ) -> Generator[None, None, None]: try: yield # Perform action within the context manager. except no_retry_exceptions: raise except retry_exceptions as err: # Exceptions is a tuple. err_msg = getattr(err, "message", repr(err)) if not last_retry or suppress_errors: # We are either explicitly asked to suppress the error # or we have retries left. if logger is not None: # `logger.exception` automatically logs `err` and its stack trace. logger.exception(err) elif ( not last_retry and check_message_contains is not None and any(message in err_msg for message in check_message_contains) ): # In this case, the error is just logged, suppressed and default # value returned if logger is not None: logger.exception(wrap_error_message_in) elif not wrap_error_message_in: raise else: msg = f"{wrap_error_message_in}: {type(err).__name__}: {str(err)}" raise type(err)(msg).with_traceback(err.__traceback__)
def _validate_and_fill_defaults( retry_on_exception_types: Optional[Tuple[Type[Exception], ...]], no_retry_on_exception_types: Optional[Tuple[Type[Exception], ...]], suppress_errors: bool, **kwargs: Any, ) -> Tuple[Tuple[Type[Exception], ...], Tuple[Type[Exception], ...], bool]: if retry_on_exception_types is None: # If no exception type provided, we catch all errors. retry_on_exception_types = (Exception,) if not isinstance(retry_on_exception_types, tuple): raise ValueError("Expected a tuple of exception types.") if no_retry_on_exception_types is not None: if not isinstance(no_retry_on_exception_types, tuple): raise ValueError("Expected a tuple of non-retriable exception types.") if set(no_retry_on_exception_types).intersection(set(retry_on_exception_types)): raise ValueError( "Same exception type cannot appear in both " "`exception_types` and `no_retry_on_exception_types`." ) # `suppress_all_errors` could be a flag to the underlying function # when used on instance methods. suppress_errors = suppress_errors or kwargs.get("suppress_all_errors", False) return retry_on_exception_types, no_retry_on_exception_types or (), suppress_errors