#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
"""Support functions for tests
"""
import contextlib
import io
import linecache
import signal
import sys
import types
import unittest
from types import FrameType
from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
from ax.utils.common.base import Base
from ax.utils.common.equality import object_attribute_dicts_find_unequal_fields
# Either an Ax `Base` object or a plain attribute dictionary (such as an object's
# `__dict__`); `_build_comparison_str` accepts both and recurses into the latter.
T_AX_BASE_OR_ATTR_DICT = Union[Base, Dict[str, Any]]
def _get_tb_lines(tb: types.TracebackType) -> List[Tuple[str, int, str]]:
"""Get the filename and line number and line contents of all the lines in the
traceback with the root at the top.
"""
res = []
opt_tb = tb
while opt_tb is not None:
lineno = opt_tb.tb_frame.f_lineno
filename = opt_tb.tb_frame.f_code.co_filename
line = linecache.getline(filename, lineno).strip()
res.append((filename, lineno, line))
opt_tb = opt_tb.tb_next
res.reverse()
return res
# pyre-fixme[24]: Generic type `unittest.case._AssertRaisesContext` expects 1 type
# parameter.
class _AssertRaisesContextOn(unittest.case._AssertRaisesContext):
"""
Attributes:
lineno: the line number on which the error occured
filename: the file in which the error occured
"""
_expected_line: Optional[str]
lineno: Optional[int]
filename: Optional[str]
def __init__(
self,
expected: Type[Exception],
test_case: unittest.TestCase,
expected_line: Optional[str] = None,
expected_regex: Optional[str] = None,
) -> None:
self._expected_line = (
expected_line.strip() if expected_line is not None else None
)
self.lineno = None
self.filename = None
# pyre-fixme[28]: Unexpected keyword argument `expected`.
super().__init__(
expected=expected, test_case=test_case, expected_regex=expected_regex
)
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[types.TracebackType],
) -> bool:
"""This is called when the context closes. If an exception was raised
`exc_type`, `exc_value` and `tb` will be set.
"""
if not super().__exit__(exc_type, exc_value, tb):
return False # reraise
# super().__exit__ will throw if exc_type is None
assert exc_type is not None
assert exc_value is not None
assert tb is not None
frames = _get_tb_lines(tb)
self.filename, self.lineno, _ = frames[0]
lines = [line for _, _, line in frames]
if self._expected_line is not None and self._expected_line not in lines:
# pyre-ignore [16]: ... has no attribute `_raiseFailure`.
self._raiseFailure(
f"{self._expected_line!r} was not found in the traceback: {lines!r}"
)
return True
# Instead of showing a warning (like in the standard library) we throw an error when
# deprecated functions are called.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def _deprecate(original_func: Callable) -> Callable:
def _deprecated_func(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
raise RuntimeError(
f"This function is deprecated please use {original_func.__name__} "
"instead."
)
return _deprecated_func
def _roman_numeral(number: int) -> str:
    """Return the lower-case Roman numeral for a positive integer, e.g. 4 -> "iv"."""
    numerals = (
        (1000, "m"),
        (900, "cm"),
        (500, "d"),
        (400, "cd"),
        (100, "c"),
        (90, "xc"),
        (50, "l"),
        (40, "xl"),
        (10, "x"),
        (9, "ix"),
        (5, "v"),
        (4, "iv"),
        (1, "i"),
    )
    digits: List[str] = []
    for value, symbol in numerals:
        count, number = divmod(number, value)
        digits.append(symbol * count)
    return "".join(digits)


def _build_comparison_str(
    first: T_AX_BASE_OR_ATTR_DICT,
    second: T_AX_BASE_OR_ATTR_DICT,
    level: int = 0,
    values_in_suffix: str = "",
) -> str:
    """Recursively build a comparison string for classes that extend Ax `Base`
    or two dictionaries (dictionaries are passed in for the recursive case).

    Prints out an 'inequality report' that includes nested fields. Lists are
    compared element-wise by treating them as dicts keyed by element index.

    NOTE: Allows recursion only up to level 4, with markers like '1)', then 'a)',
    then Roman numerals 'i)', 'ii)', ..., then '*' for the nested lists.

    For example, for two experiments, the 'report' might look like this, if their
    search spaces are unequal because of difference of parameters in parameter
    constraints:

        Experiment(test_1) (type `Experiment`) != Experiment(test_2) (type `Experiment`).

        Fields with different values:

        1) _search_space: ... != ...

            Fields with different values in 1):

            a) _parameter_constraints: ... != ...

                Fields with different values in a):

                i) _parameter: ... != ...

                    Fields with different values in i):

                    * db_id: ... != ...

    Args:
        first: The first object (or attribute dict) to compare.
        second: The second object (or attribute dict) to compare.
        level: Current nesting depth; controls indentation and bullet style.
        values_in_suffix: Text appended to the "Fields with different values"
            header, naming the parent bullet in the recursive case.

    Returns:
        The report, or an empty string if the inputs are equal or ``level`` > 3.
    """

    def _unequal_str(first: Any, second: Any) -> str:  # pyre-ignore[2]
        return f"{first} (type {type(first)}) != {second} (type {type(second)})."

    if first == second or level > 3:
        # Don't go deeper than 4 levels as the inequality report will not be legible.
        return ""
    indent = " " * level * 4
    _, unequal_val = object_attribute_dicts_find_unequal_fields(
        one_dict=first.__dict__ if isinstance(first, Base) else first,
        other_dict=second.__dict__ if isinstance(second, Base) else second,
        fast_return=False,
    )
    parts: List[str] = []
    if level == 0:
        parts.append(f"{_unequal_str(first=first, second=second)}\n")
    parts.append(f"\n{indent}Fields with different values{values_in_suffix}:\n")
    # Loop variables are named so they do not shadow the `first`/`second` args.
    for idx, (field, (first_val, second_val)) in enumerate(unequal_val.items()):
        # Level 0: "1)", level 1: "a)", level 2: Roman numerals "i)", level 3: "*".
        if level == 0:
            bul = f"{idx + 1})"
        elif level == 1:
            bul = f"{chr(ord('a') + idx)})"
        elif level == 2:
            bul = f"{_roman_numeral(idx + 1)})"
        else:
            bul = "*"
        parts.append(
            f"\n{indent}{bul} {field}: "
            f"{_unequal_str(first=first_val, second=second_val)}\n"
        )
        if isinstance(first_val, (dict, Base)) and isinstance(
            second_val, (dict, Base)
        ):
            parts.append(
                _build_comparison_str(
                    first=first_val,
                    second=second_val,
                    level=level + 1,
                    values_in_suffix=f" in {bul}",
                )
            )
        elif isinstance(first_val, list) and isinstance(second_val, list):
            # To compare lists recursively via the same function, turn them into
            # dicts keyed by the (stringified) element index.
            parts.append(
                _build_comparison_str(
                    first={str(i): v for i, v in enumerate(first_val)},
                    second={str(i): v for i, v in enumerate(second_val)},
                    level=level + 1,
                    values_in_suffix=f" in {bul}",
                )
            )
    return "".join(parts)
class TestCase(unittest.TestCase):
    """The base Ax test case, contains various helper functions to write unittests."""

    # Per-test time budget in seconds, enforced with SIGALRM in `run`.
    MAX_TEST_SECONDS = 540

    # NOTE: the SIGALRM handler used to be installed in `__init__`, which made
    # every instance overwrite the process-wide handler (so the timeout message
    # could report another test class's `MAX_TEST_SECONDS`). It is now installed
    # for the duration of `run` only, so no `__init__` override is needed.

    def run(
        self, result: Optional[unittest.result.TestResult] = None
    ) -> Optional[unittest.result.TestResult]:
        """Run the test, failing it if it exceeds `MAX_TEST_SECONDS`.

        Args:
            result: Where to record the outcome. As with `unittest.TestCase.run`,
                `None` (the default) makes unittest create a default result.

        Returns:
            Whatever `unittest.TestCase.run` returns for this test.
        """

        def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
            raise Exception(f"Test timed out at {self.MAX_TEST_SECONDS} seconds")

        previous_handler = signal.signal(signal.SIGALRM, signal_handler)
        # Arrange for a SIGALRM signal to be delivered to the calling process
        # in specified number of seconds.
        signal.alarm(self.MAX_TEST_SECONDS)
        try:
            return super().run(result)
        finally:
            signal.alarm(0)
            # `signal.signal` returns None when the previous handler was not
            # installed from Python; fall back to the default action then.
            signal.signal(
                signal.SIGALRM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )

    def assertEqual(
        self,
        first: Any,  # pyre-ignore[2]
        second: Any,  # pyre-ignore[2]
        msg: Optional[str] = None,
    ) -> None:
        """Assert equality, with a detailed report for two Ax `Base` objects.

        When both arguments are `Base` instances this delegates to
        `assertAxBaseEqual`; otherwise it defers to `unittest.TestCase.assertEqual`.
        """
        if isinstance(first, Base) and isinstance(second, Base):
            self.assertAxBaseEqual(first=first, second=second, msg=msg)
        else:
            super().assertEqual(first=first, second=second, msg=msg)

    def assertAxBaseEqual(
        self, first: Base, second: Base, msg: Optional[str] = None
    ) -> None:
        """Assert two Ax `Base` objects are equal, reporting differing fields.

        Args:
            first: The first object; must be an Ax `Base`.
            second: The second object; must be an Ax `Base`.
            msg: Optional custom message, combined with the inequality report the
                same way `unittest` combines messages (see `longMessage`).

        Raises:
            self.failureException: If either argument is not a `Base`, or the two
                objects are not equal.
        """
        self.assertIsInstance(
            first, Base, "First argument is not a subclass of Ax `Base`."
        )
        self.assertIsInstance(
            second, Base, "Second argument is not a subclass of Ax `Base`."
        )
        if first != second:
            report = _build_comparison_str(first=first, second=second)
            # Previously `msg` was silently dropped; honor it like unittest does.
            raise self.failureException(self._formatMessage(msg, report))

    def assertRaisesOn(
        self,
        exc: Type[Exception],
        line: Optional[str] = None,
        regex: Optional[str] = None,
    ) -> ContextManager[None]:
        """Assert that an exception is raised on a specific line.

        Args:
            exc: The exception type that must be raised inside the context.
            line: Source line that must appear in the exception's traceback
                (surrounding whitespace is ignored).
            regex: Optional regex the exception message must match.

        Returns:
            A context manager that performs the checks when it exits.
        """
        context = _AssertRaisesContextOn(exc, self, line, regex)
        # pyre-ignore [16]: ... has no attribute `handle`.
        return context.handle("assertRaisesOn", [], {})

    @staticmethod
    @contextlib.contextmanager
    def silence_stderr() -> Generator[None, None, None]:
        """A context manager that silences stderr for part of a test.

        If an `Exception` passes through this context manager the captured stderr
        is printed to the original stderr and the exception is re-raised;
        otherwise the captured output is discarded. The original stderr is always
        restored.
        """
        new_err = io.StringIO()
        old_err = sys.stderr
        try:
            sys.stderr = new_err
            yield
        except Exception:
            print(new_err.getvalue(), file=old_err, flush=True)
            raise
        finally:
            sys.stderr = old_err

    # This list of deprecated aliases is taken from the python standard library;
    # here calling any of them raises instead of warning.
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessEqual = assertEquals = _deprecate(unittest.TestCase.assertEqual)
    # pyre-fixme[4]: Attribute must be annotated.
    failIfEqual = assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual)
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(
        unittest.TestCase.assertAlmostEqual
    )
    # pyre-fixme[4]: Attribute must be annotated.
    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(
        unittest.TestCase.assertNotAlmostEqual
    )
    # pyre-fixme[4]: Attribute must be annotated.
    failUnless = assert_ = _deprecate(unittest.TestCase.assertTrue)
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessRaises = _deprecate(unittest.TestCase.assertRaises)
    # pyre-fixme[4]: Attribute must be annotated.
    failIf = _deprecate(unittest.TestCase.assertFalse)
    # pyre-fixme[4]: Attribute must be annotated.
    assertRaisesRegexp = _deprecate(unittest.TestCase.assertRaisesRegex)
    # pyre-fixme[4]: Attribute must be annotated.
    assertRegexpMatches = _deprecate(unittest.TestCase.assertRegex)
    # pyre-fixme[4]: Attribute must be annotated.
    assertNotRegexpMatches = _deprecate(unittest.TestCase.assertNotRegex)