Source code for ax.utils.common.testutils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Support functions for tests
"""

import contextlib
import io
import linecache
import os
import signal
import subprocess
import sys
import types
import unittest
from functools import wraps
from types import FrameType
from typing import (
    Any,
    Callable,
    ContextManager,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from unittest.mock import MagicMock

import yappi

from ax.utils.common.base import Base
from ax.utils.common.equality import object_attribute_dicts_find_unequal_fields


# Either an Ax `Base` object or a plain attribute dictionary; used to annotate the
# two operands accepted by `_build_comparison_str`.
T_AX_BASE_OR_ATTR_DICT = Union[Base, Dict[str, Any]]
# Generic type variable; used to annotate the return type of the functions
# wrapped by the `disable_profiler` decorator.
T = TypeVar("T")


def _get_tb_lines(tb: types.TracebackType) -> List[Tuple[str, int, str]]:
    """Get the filename, line number and line contents of all the entries in the
    traceback, with the root (the innermost entry, i.e. the place where the
    exception was raised) at the top.

    Args:
        tb: The traceback to walk, starting at its outermost entry.

    Returns:
        A list of ``(filename, lineno, line)`` tuples ordered from the innermost
        entry to the outermost one. ``line`` is the source text looked up through
        ``linecache`` with surrounding whitespace stripped.
    """
    res = []
    opt_tb = tb
    while opt_tb is not None:
        # Read the line from the traceback entry, not from its frame: a frame's
        # `f_lineno` is the line that frame is executing *now*, which moves on
        # when the frame keeps running after the exception passed through it
        # (e.g. a `finally` block or a context manager's `__exit__`), while
        # `tb_lineno` stays at the line where the exception occurred.
        lineno = opt_tb.tb_lineno
        filename = opt_tb.tb_frame.f_code.co_filename
        line = linecache.getline(filename, lineno).strip()
        res.append((filename, lineno, line))
        opt_tb = opt_tb.tb_next
    # The traceback chain runs outermost -> innermost; callers want the reverse.
    res.reverse()
    return res


# pyre-fixme[24]: Generic type `unittest.case._AssertRaisesContext` expects 1 type
#  parameter.
class _AssertRaisesContextOn(unittest.case._AssertRaisesContext):
    """An `assertRaises`-style context that can additionally require a given
    source line to appear in the traceback of the raised exception.

    Attributes:
       lineno: the line number on which the error occurred
       filename: the file in which the error occurred
    """

    _expected_line: Optional[str]
    lineno: Optional[int]
    filename: Optional[str]

    def __init__(
        self,
        expected: Type[Exception],
        test_case: unittest.TestCase,
        expected_line: Optional[str] = None,
        expected_regex: Optional[str] = None,
    ) -> None:
        # The error location is only known once the context has exited.
        self.filename = None
        self.lineno = None
        # Lines are compared with surrounding whitespace removed.
        if expected_line is None:
            self._expected_line = None
        else:
            self._expected_line = expected_line.strip()
        # pyre-fixme[28]: Unexpected keyword argument `expected`.
        super().__init__(
            expected=expected, test_case=test_case, expected_regex=expected_regex
        )

    # pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
    #  inconsistently.
    # pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
    #  inconsistently.
    # pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
    #  inconsistently.
    def __exit__(
        self,
        exc_type: Optional[Type[Exception]],
        exc_value: Optional[Exception],
        tb: Optional[types.TracebackType],
    ) -> bool:
        """Run the standard `assertRaises` checks when the context closes, then
        record where the error happened and, if a line was requested, check that
        it is part of the traceback.

        `exc_type`, `exc_value` and `tb` are set when an exception was raised
        inside the context.
        """
        swallowed = super().__exit__(exc_type, exc_value, tb)
        if not swallowed:
            return False  # reraise
        # super().__exit__ will throw if exc_type is None
        assert exc_type is not None
        assert exc_value is not None
        assert tb is not None
        tb_entries = _get_tb_lines(tb)
        self.filename, self.lineno, _ = tb_entries[0]
        source_lines = [entry[2] for entry in tb_entries]
        wanted = self._expected_line
        if wanted is None or wanted in source_lines:
            return True
        # pyre-ignore [16]: ... has no attribute `_raiseFailure`.
        self._raiseFailure(
            f"{wanted!r} was not found in the traceback: {source_lines!r}"
        )
        return True


# Where the standard library only shows a warning, calling one of the deprecated
# functions is treated as an error here.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def _deprecate(original_func: Callable) -> Callable:
    """Return a stand-in for a deprecated alias of `original_func`.

    The returned function accepts any arguments and always raises a
    `RuntimeError` that names `original_func` as the replacement to use.
    """

    def _deprecated_func(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
        replacement = original_func.__name__
        message = f"This function is deprecated please use {replacement} instead."
        raise RuntimeError(message)

    return _deprecated_func


def _build_comparison_str(
    first: T_AX_BASE_OR_ATTR_DICT,
    second: T_AX_BASE_OR_ATTR_DICT,
    level: int = 0,
    values_in_suffix: str = "",
) -> str:
    """Recursively build an 'inequality report' for two objects that extend Ax
    `Base`, or for two dictionaries (dictionaries are what the recursive calls
    pass in). The report lists the fields whose values differ, descending into
    nested `Base` objects, dictionaries and lists.

    NOTE: Recursion goes at most 4 levels deep, with bullets like '1)', then
    'a)', then 'i)', then '*' for the successive nesting levels.

    For example, for two experiments, the 'report' might look like this, if their
    search spaces are unequal because of difference of parameters in parameter
    constraints:

    Experiment(test_1) (type `Experiment`) != Experiment(test_2) (type `Experiment`).

    Fields with different values:

    1) _search_space: ... != ...

        Fields with different values in 1):

            a) _parameter_constraints: ... != ...

                Fields with different values in a):

                    i) _parameter: ... != ...

                        Fields with different values in i):

                            * db_id: ... != ...

    Args:
        first: First object (or attribute dictionary) to compare.
        second: Second object (or attribute dictionary) to compare.
        level: Current nesting depth; 0 for the top-level call.
        values_in_suffix: Text appended to the "Fields with different values"
            header, naming the bullet of the enclosing field.

    Returns:
        The report text, or an empty string when the operands are equal or the
        nesting is deeper than 4 levels.
    """
    if first == second or level > 3:
        # Don't go deeper than 4 levels as the inequality report will not be legible.
        return ""

    def _describe_mismatch(left: Any, right: Any) -> str:  # pyre-ignore[2]
        return f"{left} (type {type(left)}) != {right} (type {type(right)})."

    def _bullet_for(position: int) -> str:
        # Level 0 uses numbers, level 1 letters, level 2 runs of "i", and
        # anything else a plain "*".
        if level == 0:
            return f"{position + 1})"
        if level == 1:
            return f"{chr(ord('a') + position)})"
        if level == 2:
            return f"{'i' * (position + 1)})"
        return "*"

    pad = " " * level * 4
    _, differing = object_attribute_dicts_find_unequal_fields(
        one_dict=first.__dict__ if isinstance(first, Base) else first,
        other_dict=second.__dict__ if isinstance(second, Base) else second,
        fast_return=False,
    )

    chunks = []
    if level == 0:
        # Only the top-level call states the overall inequality.
        chunks.append(f"{_describe_mismatch(first, second)}\n")
    chunks.append(f"\n{pad}Fields with different values{values_in_suffix}:\n")

    for position, (field_name, (left_val, right_val)) in enumerate(differing.items()):
        bullet = _bullet_for(position)
        chunks.append(
            f"\n{pad}{bullet} {field_name}: "
            f"{_describe_mismatch(left_val, right_val)}\n"
        )
        nested_suffix = f" in {bullet}"
        if isinstance(left_val, (dict, Base)) and isinstance(right_val, (dict, Base)):
            chunks.append(
                _build_comparison_str(
                    first=left_val,
                    second=right_val,
                    level=level + 1,
                    values_in_suffix=nested_suffix,
                )
            )
        elif isinstance(left_val, list) and isinstance(right_val, list):
            # Lists go through the same function as dictionaries keyed by the
            # (stringified) element index.
            chunks.append(
                _build_comparison_str(
                    first={str(i): item for i, item in enumerate(left_val)},
                    second={str(i): item for i, item in enumerate(right_val)},
                    level=level + 1,
                    values_in_suffix=nested_suffix,
                )
            )
    return "".join(chunks)


def setup_import_mocks(mocked_import_paths: List[str]) -> None:
    """Register a separate `MagicMock` in `sys.modules` under each of the given
    import paths.
    """
    sys.modules.update({path: MagicMock() for path in mocked_import_paths})


class TestCase(unittest.TestCase):
    """The base Ax test case, contains various helper functions to write unittests."""

    # try to remove these files on tearDown
    FILES_TO_CLEAN: List[str] = []
    MAX_TEST_SECONDS = 5400000
    MIN_TTOT = 1.0
    PROFILER_COLUMNS = {
        0: ("name", 100),  # default 36 is often not enough
        1: ("ncall", 5),
        2: ("tsub", 8),
        3: ("ttot", 8),
        4: ("tavg", 8),
    }
    # some test cases have a conflict with the profiler(P511247687)
    # it appears to be pytorch related https://fburl.com/wiki/r8u9f3rs
    # set to `False` on the specific testcase class if this occurs
    CAN_PROFILE = True
    _prior_status: Optional[str] = None

    def __init__(self, methodName: str = "runTest") -> None:
        """Create the test case and install the `SIGALRM` handler that `run`
        relies on to abort tests exceeding `MAX_TEST_SECONDS`.
        """

        def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
            # When profiling is enabled, print the yappi function stats whose
            # `ttot` exceeds `MIN_TTOT` before failing the test.
            if self.CAN_PROFILE:
                yappi.get_func_stats(
                    filter_callback=lambda s: s.ttot > self.MIN_TTOT
                ).sort(sort_type="ttot", sort_order="asc").print_all(
                    columns=self.PROFILER_COLUMNS
                )
            raise Exception(f"Test timed out at {self.MAX_TEST_SECONDS} seconds")

        super().__init__(methodName=methodName)
        signal.signal(signal.SIGALRM, signal_handler)

    def tearDown(self) -> None:
        """Remove every file listed in `FILES_TO_CLEAN` that still exists."""
        for f in self.FILES_TO_CLEAN:
            if os.path.exists(f):
                os.remove(f)

    def run(
        self, result: Optional[unittest.result.TestResult] = None
    ) -> Optional[unittest.result.TestResult]:
        """Run the test under a timeout alarm and, when `CAN_PROFILE` is set,
        under the yappi profiler; afterwards check that the repository status
        is the same as before the test.

        Args:
            result: The result object to record into. `None` (the default) is
                forwarded to `unittest.TestCase.run`, which then creates its own.

        Returns:
            Whatever `unittest.TestCase.run` returns.
        """
        # Arrange for a SIGALRM signal to be delivered to the calling process
        # in specified number of seconds.
        signal.alarm(self.MAX_TEST_SECONDS)
        self._prior_status = self._get_repository_status()
        try:
            if self.CAN_PROFILE:
                yappi.set_clock_type("wall")
                yappi.start()
            result = super().run(result)
        finally:
            if self.CAN_PROFILE:
                yappi.stop()
            # Cancel the pending alarm whether or not the test run succeeded.
            signal.alarm(0)
        self._assert_status_is_unchanged()
        return result

    def _get_repository_status(self) -> str:
        """Return the standard output of `hg status`, decoded as UTF-8.

        An empty string is returned when the `hg` executable cannot be found, so
        that tests can still run on machines without Mercurial.
        """
        try:
            completed = subprocess.run(
                ["hg", "status"],
                capture_output=True,
            )
        except FileNotFoundError:
            # No `hg` binary: there is no status to snapshot or compare.
            return ""
        return completed.stdout.decode("utf-8")

    def _assert_status_is_unchanged(self) -> None:
        """Fail if the repository status differs from the one recorded in
        `_prior_status`.
        """
        post_status = self._get_repository_status()
        self.assertEqual(
            self._prior_status,
            post_status,
            "Files in the repository were modified while this test was running",
        )

    def assertEqual(
        self,
        first: Any,  # pyre-ignore[2]
        second: Any,  # pyre-ignore[2]
        msg: Optional[str] = None,
    ) -> None:
        """Assert equality, using `assertAxBaseEqual` when both operands are Ax
        `Base` objects and the standard `unittest` comparison otherwise.
        """
        if isinstance(first, Base) and isinstance(second, Base):
            self.assertAxBaseEqual(first=first, second=second, msg=msg)
        else:
            super().assertEqual(first=first, second=second, msg=msg)

    def assertAxBaseEqual(
        self, first: Base, second: Base, msg: Optional[str] = None
    ) -> None:
        """Assert that two Ax `Base` objects are equal; on failure the message
        is the field-by-field report from `_build_comparison_str`.

        Args:
            first: First object to compare.
            second: Second object to compare.
            msg: Optional custom message, combined with the report following
                the usual `unittest` message rules (`longMessage`).
        """
        self.assertIsInstance(
            first, Base, "First argument is not a subclass of Ax `Base`."
        )
        self.assertIsInstance(
            second, Base, "Second argument is not a subclass of Ax `Base`."
        )
        if first != second:
            report = _build_comparison_str(first=first, second=second)
            # `_formatMessage` returns the report unchanged when `msg` is None.
            raise self.failureException(self._formatMessage(msg, report))

    def assertRaisesOn(
        self,
        exc: Type[Exception],
        line: Optional[str] = None,
        regex: Optional[str] = None,
    ) -> ContextManager[None]:
        """Assert that an exception is raised on a specific line."""
        context = _AssertRaisesContextOn(exc, self, line, regex)
        # pyre-ignore [16]: ... has no attribute `handle`.
        return context.handle("assertRaisesOn", [], {})

    @staticmethod
    @contextlib.contextmanager
    def silence_stderr() -> Generator[None, None, None]:
        """A context manager that silences stderr for part of a test.

        If any exception passes through this context manager the stderr will be
        printed, otherwise it will be discarded.
        """
        new_err = io.StringIO()
        old_err = sys.stderr
        try:
            sys.stderr = new_err
            yield
        except Exception:
            print(new_err.getvalue(), file=old_err, flush=True)
            raise
        finally:
            sys.stderr = old_err

    # This list is taken from the python standard library
    # pyre-fixme[4]: Attribute must be annotated.
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessEqual = assertEquals = _deprecate(unittest.TestCase.assertEqual)
    # pyre-fixme[4]: Attribute must be annotated.
    # pyre-fixme[4]: Attribute must be annotated.
    failIfEqual = assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual)
    # pyre-fixme[4]: Attribute must be annotated.
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(
        unittest.TestCase.assertAlmostEqual
    )
    # pyre-fixme[4]: Attribute must be annotated.
    # pyre-fixme[4]: Attribute must be annotated.
    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(
        unittest.TestCase.assertNotAlmostEqual
    )
    # pyre-fixme[4]: Attribute must be annotated.
    # pyre-fixme[4]: Attribute must be annotated.
    failUnless = assert_ = _deprecate(unittest.TestCase.assertTrue)
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessRaises = _deprecate(unittest.TestCase.assertRaises)
    # pyre-fixme[4]: Attribute must be annotated.
    failIf = _deprecate(unittest.TestCase.assertFalse)
    # pyre-fixme[4]: Attribute must be annotated.
    assertRaisesRegexp = _deprecate(unittest.TestCase.assertRaisesRegex)
    # pyre-fixme[4]: Attribute must be annotated.
    assertRegexpMatches = _deprecate(unittest.TestCase.assertRegex)
    # pyre-fixme[4]: Attribute must be annotated.
    assertNotRegexpMatches = _deprecate(unittest.TestCase.assertNotRegex)
def disable_profiler(f: Callable[..., T]) -> Callable[..., T]:
    """Decorator for individual tests that are not compatible with yappi.

    The wrapped test stops the profiler before it runs, which avoids having to
    disable profiling for the entire test case.
    """

    @wraps(f)
    def _without_profiler(self: TestCase, *args: Any, **kwargs: Any) -> T:
        yappi.stop()
        outcome = f(self, *args, **kwargs)
        return outcome

    return _without_profiler