#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
"""Support functions for tests"""
import builtins
import contextlib
import cProfile
import io
import linecache
import logging
import os
import signal
import sys
import types
import unittest
import warnings
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager
from logging import Logger
from pstats import Stats
from types import FrameType, ModuleType
from typing import Any, TypeVar, Union
from unittest.mock import MagicMock
import numpy as np
from ax.exceptions.core import AxParameterWarning
from ax.utils.common.base import Base
from ax.utils.common.constants import TESTENV_ENV_KEY, TESTENV_ENV_VAL
from ax.utils.common.equality import object_attribute_dicts_find_unequal_fields
from ax.utils.common.logger import get_logger
from botorch.exceptions.warnings import InputDataWarning
from pyfakefs import fake_filesystem_unittest
# Either an Ax `Base` object or a plain attribute dictionary; this is the type of
# the two operands accepted by `_build_comparison_str`.
T_AX_BASE_OR_ATTR_DICT = Union[Base, dict[str, Any]]
# Deepest recursion level at which `_build_comparison_str` still lists unequal
# fields; past this level the report is cut short with a note.
COMPARISON_STR_MAX_LEVEL = 8
# Generic type variable (not referenced by the definitions visible in this module
# section).
T = TypeVar("T")
# Module-level logger, used by the test timeout handler to report long-running tests.
logger: Logger = get_logger(__name__)
def _get_tb_lines(tb: types.TracebackType) -> list[tuple[str, int, str]]:
    """Get the filename, line number and line contents of every entry in the
    traceback, with the root (the frame that raised) at the top.

    Args:
        tb: The traceback to walk, e.g. the one passed to ``__exit__``.

    Returns:
        One ``(filename, lineno, line)`` tuple per traceback entry, ordered from
        the innermost frame to the outermost one. ``line`` is the stripped source
        text of that line, or an empty string when the source cannot be read.
    """
    res = []
    opt_tb: types.TracebackType | None = tb
    while opt_tb is not None:
        frame = opt_tb.tb_frame
        # Use the line recorded on the traceback entry, not `frame.f_lineno`:
        # the frame's current line moves on when it runs a `finally` block or
        # a context manager's `__exit__`, so it may no longer point at the
        # statement through which the exception propagated.
        lineno = opt_tb.tb_lineno
        filename = frame.f_code.co_filename
        # Passing the module globals lets `linecache` fall back to the module's
        # loader when the file is not directly readable from disk.
        line = linecache.getline(filename, lineno, frame.f_globals).strip()
        res.append((filename, lineno, line))
        opt_tb = opt_tb.tb_next
    # The traceback chain starts at the outermost frame; callers want the root first.
    res.reverse()
    return res
# pyre-fixme[24]: Generic type `unittest.case._AssertRaisesContext` expects 1 type
# parameter.
class _AssertRaisesContextOn(unittest.case._AssertRaisesContext):
"""
Attributes:
lineno: the line number on which the error occured
filename: the file in which the error occured
"""
_expected_line: str | None
lineno: int | None
filename: str | None
def __init__(
self,
expected: type[Exception],
test_case: unittest.TestCase,
expected_line: str | None = None,
expected_regex: str | None = None,
) -> None:
self._expected_line = (
expected_line.strip() if expected_line is not None else None
)
self.lineno = None
self.filename = None
super().__init__(
expected=expected, test_case=test_case, expected_regex=expected_regex
)
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
def __exit__(
self,
exc_type: type[Exception] | None,
exc_value: Exception | None,
tb: types.TracebackType | None,
) -> bool:
"""This is called when the context closes. If an exception was raised
`exc_type`, `exc_value` and `tb` will be set.
"""
if not super().__exit__(exc_type, exc_value, tb):
return False # reraise
# super().__exit__ will throw if exc_type is None
assert exc_type is not None
assert exc_value is not None
assert tb is not None
frames = _get_tb_lines(tb)
self.filename, self.lineno, _ = frames[0]
lines = [line for _, _, line in frames]
if self._expected_line is not None and self._expected_line not in lines:
# pyre-ignore [16]: ... has no attribute `_raiseFailure`.
self._raiseFailure(
f"{self._expected_line!r} was not found in the traceback: {lines!r}"
)
return True
# The standard library only warns when a deprecated alias is called; here calling
# one is an outright error.
def _deprecate(original_func: Callable) -> Callable:
    """Build a stand-in for a deprecated alias of ``original_func``.

    Args:
        original_func: The function that should be used instead of the alias.

    Returns:
        A callable that accepts any arguments and always raises ``RuntimeError``
        naming ``original_func`` as the replacement.
    """

    def _deprecated_func(*args: Any, **kwargs: Any) -> None:
        message = (
            f"This function is deprecated please use {original_func.__name__} "
            "instead."
        )
        raise RuntimeError(message)

    return _deprecated_func
def _build_comparison_str(
    first: T_AX_BASE_OR_ATTR_DICT,
    second: T_AX_BASE_OR_ATTR_DICT,
    level: int = 0,
    values_in_suffix: str = "",
    skip_db_id_check: bool = False,
) -> str:
    """Recursively build an 'inequality report' for two objects that extend Ax
    ``Base`` or for two dictionaries (dictionaries are what gets passed in the
    recursive case). The report lists the unequal fields, including nested ones.

    NOTE: Recursion stops once ``level`` exceeds ``COMPARISON_STR_MAX_LEVEL``.
    Fields are bulleted with '1)', '2)', ... at level 0, 'a)', 'b)', ... at
    level 1, 'i)', 'ii)', ... at level 2 and '*' at every deeper level.

    For example, for two experiments, the report might look like this (blank
    lines omitted), if their search spaces are unequal because of a difference
    of parameters in parameter constraints:

    Experiment(test_1) (type `Experiment`) != Experiment(test_2) (type `Experiment`).
    Fields with different values:
    1) _search_space: ... != ...
        Fields with different values in 1):
        a) _parameter_constraints: ... != ...
            Fields with different values in a):
            i) _parameter: ... != ...
                Fields with different values in i):
                * db_id: ... != ...

    Args:
        first: First object (or attribute dict) to compare.
        second: Second object (or attribute dict) to compare.
        level: Current nesting depth; 0 for the outermost call.
        values_in_suffix: Text appended to the "Fields with different values"
            header, naming the bullet of the parent field.
        skip_db_id_check: If ``True``, will exclude the ``db_id`` attributes from
            the equality check. Useful for ensuring that all attributes of an
            object are equal except the ids, with which one or both of them are
            saved to the database (e.g. if confirming an object before it was
            saved, to the version reloaded from the DB).

    Returns:
        The report, or an empty string if ``first == second``.
    """

    def _describe_mismatch(left: Any, right: Any) -> str:  # pyre-ignore[2]
        return f"{left} (type {type(left)}) != {right} (type {type(right)})."

    def _bullet_for(position: int) -> str:
        # Numbers at level 0, letters at level 1, repeated "i" at level 2 and
        # "*" for anything deeper (up to `COMPARISON_STR_MAX_LEVEL`).
        if level == 0:
            return f"{position + 1})"
        if level == 1:
            return f"{chr(ord('a') + position)})"
        if level == 2:
            return f"{'i' * (position + 1)})"
        if level <= COMPARISON_STR_MAX_LEVEL:
            return "*"
        raise RuntimeError(
            "Reached level > `COMPARISON_STR_MAX_LEVEL`, which should've been "
            "unreachable."
        )

    if first == second:
        return ""

    if level > COMPARISON_STR_MAX_LEVEL:
        # Deeper levels are not reported, as the report would stop being legible.
        return (
            f"\n... also there were unequal fields at levels {level}+; "
            "to see full comparison past this level, adjust `ax.utils.common.testutils."
            "COMPARISON_STR_MAX_LEVEL`"
        )

    first_attrs = first.__dict__ if isinstance(first, Base) else first
    second_attrs = second.__dict__ if isinstance(second, Base) else second
    unequal_types, unequal_val = object_attribute_dicts_find_unequal_fields(
        one_dict=first_attrs,
        other_dict=second_attrs,
        fast_return=False,
        skip_db_id_check=skip_db_id_check,
    )

    indent = " " * level * 4
    pieces = []
    if level == 0:
        pieces.append(f"{_describe_mismatch(first, second)}\n")
    pieces.append(f"\n{indent}Fields with different values{values_in_suffix}:\n")

    # Fields with unequal values come first, followed by the fields whose values
    # had unequal types (marked as such in their name).
    mismatched_fields = dict(unequal_val)
    for name, value_pair in unequal_types.items():
        mismatched_fields[f"{name} (field had values of unequal type)"] = value_pair

    for idx, (field, (first_val, second_val)) in enumerate(
        mismatched_fields.items()
    ):
        bul = _bullet_for(idx)
        pieces.append(
            f"\n{indent}{bul} {field}: {_describe_mismatch(first_val, second_val)}\n"
        )

        if isinstance(first_val, (dict, Base)) and isinstance(
            second_val, (dict, Base)
        ):
            nested_first, nested_second = first_val, second_val
        elif isinstance(first_val, list) and isinstance(second_val, list):
            # Lists are compared through the same function by turning them into
            # dicts keyed by the (stringified) index.
            nested_first = {str(i): item for i, item in enumerate(first_val)}
            nested_second = {str(i): item for i, item in enumerate(second_val)}
        else:
            continue

        pieces.append(
            _build_comparison_str(
                first=nested_first,
                second=nested_second,
                level=level + 1,
                values_in_suffix=f" in {bul}",
                skip_db_id_check=skip_db_id_check,
            )
        )

    return "".join(pieces)
[docs]
def setup_import_mocks(
mocked_import_paths: list[str], mock_config_dict: dict[str, Any] | None = None
) -> None:
"""This function mocks expensive modules used in tests. It must be called before
those modules are imported or it will not work. Stubbing out these modules
will obviously affect the behavior of all tests that use it, so be sure modules
being mocked are not important to your test. It will also mock all child modules.
Args:
mocked_import_paths: List of module paths to mock.
mock_config_dict: Dictionary of attributes to mock on the modules being mocked.
This is useful if the import is expensive, but there is still some
functionality it has the test relies on. These attributes will be
set on all modules being mocked.
"""
# pyre-fixme[3]
def custom_import(name: str, *args: Any, **kwargs: Any) -> ModuleType:
for import_path in mocked_import_paths:
if name == import_path or name.startswith(f"{import_path}."):
mymock = MagicMock()
if mock_config_dict is not None:
mymock.configure_mock(**mock_config_dict)
return mymock
return original_import(name, *args, **kwargs)
for import_path in mocked_import_paths:
if import_path in sys.modules and not isinstance(
sys.modules[import_path], MagicMock
):
raise Exception(f"{import_path} has already been imported!")
# Replace the original import with the custom one
# pyre-fixme[61][53]
original_import: Callable[..., ModuleType] = builtins.__import__
# pyre-fixme[9]: __import__ has type `(name: str, globals: Optional[Mapping[str,
# object]] = ..., locals: Optional[Mapping[str, object]] = ..., fromlist:
# Sequence[str] = ..., level: int = ...) -> ModuleType`; used as `(name: str,
# *(Any), **(Any)) -> Any`.
builtins.__import__ = custom_import
class TestCase(fake_filesystem_unittest.TestCase):
    """The base Ax test case, contains various helper functions to write unittests.

    Attributes:
        MAX_TEST_SECONDS: Number of seconds after which a ``SIGALRM`` is delivered
            to a running test (see ``run``).
        NUMBER_OF_PROFILER_LINES_TO_OUTPUT: How many trailing lines of the profiler
            report are printed by ``_print_profiler_output``.
        PROFILE_TESTS: Whether to profile each test and print the profiler output
            when a test times out.
    """

    MAX_TEST_SECONDS = 60
    NUMBER_OF_PROFILER_LINES_TO_OUTPUT = 20
    PROFILE_TESTS = False
    _long_test_active_reason: str | None = None

    def __init__(self, methodName: str = "runTest") -> None:
        """Create the test case, install the timeout handler and flag the process
        as running in a test environment.

        Args:
            methodName: Name of the test method to run.
        """

        def signal_handler(signum: int, frame: FrameType | None) -> None:
            message = f"Test took longer than {self.MAX_TEST_SECONDS} seconds."
            if self.PROFILE_TESTS:
                self._print_profiler_output()
            else:
                message += (
                    " To see a profiler output, set `TestCase.PROFILE_TESTS` to `True`."
                )
            if hasattr(sys, "gettrace") and sys.gettrace() is not None:
                # If we're in a debugger session, let the test continue running.
                return
            elif self._long_test_active_reason is None:
                message += (
                    " To specify a reason for a long running test,"
                    + " utilize the @ax_long_test decorator. If your test "
                    + "is long because it's doing modeling, please use the "
                    + "@mock_botorch_optimize decorator and see if that helps."
                )
                raise TimeoutError(message)
            else:
                # A reason was given for the test being long: only warn.
                message += (
                    " Reason for long running test: " + self._long_test_active_reason
                )
                logger.warning(message)

        super().__init__(methodName=methodName)
        # NOTE(review): signal handlers are process-wide and this one closes over
        # `self`, so the handler in effect belongs to the most recently
        # constructed test case, which need not be the one currently running.
        signal.signal(signal.SIGALRM, signal_handler)
        # This is set to indicate we are running in a test environment. Code can check
        # this to:
        # * more strictly enforce SQL encoding
        # (https://github.com/facebook/Ax/blob/main/ax/storage/sqa_store/save.py#L598)
        # * avoid actions that will affect product environments
        os.environ[TESTENV_ENV_KEY] = TESTENV_ENV_VAL

    def setUp(self) -> None:
        """
        Only show log messages of WARNING or higher while testing.

        Ax prints a lot of INFO logs that are not relevant for unit tests.
        Also silences a number of common warnings originating from Ax & BoTorch,
        and starts the profiler if ``PROFILE_TESTS`` is set.
        """
        super().setUp()
        self.profiler = cProfile.Profile()
        if self.PROFILE_TESTS:
            self.profiler.enable()
            self.addCleanup(self.profiler.disable)

        ax_logger = get_logger(__name__, level=logging.WARNING)
        # Parent handlers are shared, so setting the level this
        # way applies it to all Ax loggers.
        parent = ax_logger.parent
        # Guard against a parent without any handlers, for which indexing the
        # first handler would raise an `IndexError`.
        if parent is not None and getattr(parent, "handlers", None):
            parent.handlers[0].setLevel(logging.WARNING)

        # Choice parameter default parameter type / is_ordered warnings.
        warnings.filterwarnings(
            "ignore",
            message=".*is not specified for .ChoiceParameter.*",
            category=AxParameterWarning,
        )
        # BoTorch float32 warning.
        warnings.filterwarnings(
            "ignore",
            message="The model inputs are of type",
            category=InputDataWarning,
        )
        # BoTorch input standardization warnings.
        warnings.filterwarnings(
            "ignore",
            message=r"Data \(outcome observations\) is not standardized ",
            category=InputDataWarning,
        )
        warnings.filterwarnings(
            "ignore",
            message=r"Data \(input features\) is not",
            category=InputDataWarning,
        )

    def run(
        self, result: unittest.result.TestResult | None = None
    ) -> unittest.result.TestResult | None:
        """Run the test with a timer armed for ``MAX_TEST_SECONDS``.

        Args:
            result: Object collecting the test outcome. Defaults to ``None``, in
                which case the parent implementation creates one itself.

        Returns:
            Whatever the parent ``run`` returns.
        """
        # Arrange for a SIGALRM signal to be delivered to the calling process
        # in specified number of seconds.
        signal.alarm(self.MAX_TEST_SECONDS)
        try:
            result = super().run(result)
        finally:
            # Cancel the pending alarm, whether or not the test succeeded.
            signal.alarm(0)
        return result

    def assertEqual(
        self,
        first: Any,  # pyre-ignore[2]
        second: Any,  # pyre-ignore[2]
        msg: str | None = None,
    ) -> None:
        """Assert equality; two Ax ``Base`` objects are compared with
        ``assertAxBaseEqual``, which produces an in-depth report on failure.

        Args:
            first: First object to compare.
            second: Second object to compare.
            msg: Optional message to use on failure.
        """
        if isinstance(first, Base) and isinstance(second, Base):
            self.assertAxBaseEqual(first=first, second=second, msg=msg)
        else:
            super().assertEqual(first=first, second=second, msg=msg)

    def assertAxBaseEqual(
        self,
        first: Base,
        second: Base,
        msg: str | None = None,
        skip_db_id_check: bool = False,
    ) -> None:
        """Check that two Ax objects that subclass ``Base`` are equal or raise
        assertion error otherwise.

        Args:
            first: ``Base``-subclassing object to compare to ``second``.
            second: ``Base``-subclassing object to compare to ``first``.
            msg: Message to put into the assertion error raised on inequality; it
                is combined with the default in-depth report in the same way
                standard ``unittest`` assertions combine ``msg`` with their
                default message. If not specified, only the default is used.
            skip_db_id_check: If ``True``, will exclude the ``db_id`` attributes from
                the equality check. Useful for ensuring that all attributes of an
                object are equal except the ids, with which one or both of them are
                saved to the database (e.g. if confirming an object before it was
                saved, to the version reloaded from the DB).
        """
        self.assertIsInstance(
            first, Base, "First argument is not a subclass of Ax `Base`."
        )
        self.assertIsInstance(
            second, Base, "Second argument is not a subclass of Ax `Base`."
        )
        if skip_db_id_check:
            unequal = not first._eq_skip_db_id_check(other=second)
        else:
            unequal = first != second
        if unequal:
            standard_msg = (
                "Encountered unequal objects. "
                "Attempting in-depth comparison; note that this recurs through the"
                " attributes of the objects being compared multiple times!\n\n"
                + _build_comparison_str(
                    first=first, second=second, skip_db_id_check=skip_db_id_check
                )
            )
            # Honor the caller's `msg` instead of silently dropping it.
            raise self.failureException(self._formatMessage(msg, standard_msg))

    def assertRaisesOn(
        self,
        exc: type[Exception],
        line: str | None = None,
        regex: str | None = None,
        # pyre-ignore[24]: Generic type `AbstractContextManager`
        # expects 2 type parameters, received 1.
    ) -> AbstractContextManager[None]:
        """Assert that an exception is raised on a specific line.

        Args:
            exc: The expected exception type.
            line: Source line that must appear (compared without surrounding
                whitespace) among the lines of the traceback; not checked if
                ``None``.
            regex: Regular expression handed to the underlying ``assertRaises``
                context as the expected regex; not checked if ``None``.

        Returns:
            A context manager to wrap the code that is expected to raise.
        """
        context = _AssertRaisesContextOn(exc, self, line, regex)
        return context.handle("assertRaisesOn", [], {})

    def assertDictsAlmostEqual(
        self, a: dict[str, Any], b: dict[str, Any], consider_nans_equal: bool = False
    ) -> None:
        """Testing utility that checks that
        1) the keys of `a` and `b` are identical, and that
        2) the values of `a` and `b` are almost equal if the value in `a` is a
        ``float`` (optionally considering NaNs as equal), and otherwise just equal.

        Args:
            a: A dictionary.
            b: Another dictionary.
            consider_nans_equal: Whether to consider NaNs equal when comparing floating
                point numbers.
        """
        set_a = set(a.keys())
        set_b = set(b.keys())
        key_msg = (
            "Dict keys differ. "
            f"Keys that are in a but not b: {set_a - set_b}. "
            f"Keys that are in b but not a: {set_b - set_a}."
        )
        self.assertEqual(set_a, set_b, msg=key_msg)
        for field in b:
            a_field = a[field]
            b_field = b[field]
            msg = f"Dict values differ for key {field}: {a[field]=}, {b[field]=}."
            # for floating point values, compare approximately and, if requested,
            # consider NaNs equal
            if isinstance(a_field, float):
                if consider_nans_equal and np.isnan(a_field):
                    self.assertTrue(np.isnan(b_field), msg=msg)
                else:
                    self.assertAlmostEqual(a_field, b_field, msg=msg)
            else:
                self.assertEqual(a_field, b_field, msg=msg)

    def assertIsSubDict(
        self,
        subdict: dict[str, Any],
        superdict: dict[str, Any],
        almost_equal: bool = False,
        consider_nans_equal: bool = False,
    ) -> None:
        """Testing utility that checks that all keys and values of `subdict` are
        contained in `superdict`.

        Args:
            subdict: A smaller dictionary.
            superdict: A larger dictionary which should contain all keys of subdict
                and the same values as subdict for the corresponding keys.
            almost_equal: Whether to compare the values with
                ``assertDictsAlmostEqual`` rather than for exact equality.
            consider_nans_equal: Whether to consider NaNs equal; can only be used
                together with ``almost_equal``.

        Raises:
            ValueError: If ``consider_nans_equal`` is set without ``almost_equal``.
        """
        if consider_nans_equal and not almost_equal:
            raise ValueError(
                "`consider_nans_equal` can only be used with `almost_equal`"
            )
        # The part of `superdict` restricted to the keys of `subdict`; a key that
        # is missing from `superdict` surfaces as a difference below.
        intersection_dict = {k: superdict[k] for k in subdict if k in superdict}
        if almost_equal:
            self.assertDictsAlmostEqual(
                subdict, intersection_dict, consider_nans_equal=consider_nans_equal
            )
        else:
            self.assertEqual(subdict, intersection_dict)

    @staticmethod
    @contextlib.contextmanager
    def silence_stderr() -> Generator[None, None, None]:
        """A context manager that silences stderr for part of a test.

        If any exception passes through this context manager the stderr will be printed,
        otherwise it will be discarded.
        """
        new_err = io.StringIO()
        old_err = sys.stderr
        try:
            sys.stderr = new_err
            yield
        except Exception:
            print(new_err.getvalue(), file=old_err, flush=True)
            raise
        finally:
            sys.stderr = old_err

    def _print_profiler_output(self) -> None:
        """Print profiler output to stdout."""
        s = io.StringIO()
        # Sorted by cumulative time in reverse order, so that the entries with
        # the largest cumulative time end up at the bottom of the report.
        ps = Stats(self.profiler, stream=s).sort_stats("cumulative").reverse_order()
        ps.print_stats()
        output = s.getvalue().splitlines()
        headers = output[:5]
        # Print the headers
        for line in headers:
            print(line)
        # Print the longest running functions
        for line in output[-self.NUMBER_OF_PROFILER_LINES_TO_OUTPUT :]:
            print(line)

    @classmethod
    @contextlib.contextmanager
    def ax_long_test(cls, reason: str | None) -> Generator[None, None, None]:
        """Context manager that records, for its duration, the reason why a test
        is allowed to run long.

        While a reason is set, the timeout handler logs a warning instead of
        raising a ``TimeoutError``.

        Args:
            reason: Why the test is expected to exceed ``MAX_TEST_SECONDS``.
        """
        cls._long_test_active_reason = reason
        try:
            yield
        finally:
            # Reset even if the wrapped code raises, so that the reason does not
            # leak into tests that run afterwards.
            cls._long_test_active_reason = None

    # This list is taken from the python standard library
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessEqual = assertEquals = _deprecate(unittest.TestCase.assertEqual)
    # pyre-fixme[4]: Attribute must be annotated.
    failIfEqual = assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual)
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(
        unittest.TestCase.assertAlmostEqual
    )
    # pyre-fixme[4]: Attribute must be annotated.
    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(
        unittest.TestCase.assertNotAlmostEqual
    )
    # pyre-fixme[4]: Attribute must be annotated.
    failUnless = assert_ = _deprecate(unittest.TestCase.assertTrue)
    # pyre-fixme[4]: Attribute must be annotated.
    failUnlessRaises = _deprecate(unittest.TestCase.assertRaises)
    # pyre-fixme[4]: Attribute must be annotated.
    failIf = _deprecate(unittest.TestCase.assertFalse)
    # pyre-fixme[4]: Attribute must be annotated.
    assertRaisesRegexp = _deprecate(unittest.TestCase.assertRaisesRegex)
    # pyre-fixme[4]: Attribute must be annotated.
    assertRegexpMatches = _deprecate(unittest.TestCase.assertRegex)
    # pyre-fixme[4]: Attribute must be annotated.
    assertNotRegexpMatches = _deprecate(unittest.TestCase.assertNotRegex)