#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
"""Support functions for tests
"""
import contextlib
import io
import linecache
import sys
import types
import unittest
from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
)
from ax.utils.common.equality import Base, object_attribute_dicts_find_unequal_fields
def _get_tb_lines(tb: types.TracebackType) -> List[Tuple[str, int, str]]:
"""Get the filename and line number and line contents of all the lines in the
traceback with the root at the top.
"""
res = []
opt_tb = tb
while opt_tb is not None:
lineno = opt_tb.tb_frame.f_lineno
filename = opt_tb.tb_frame.f_code.co_filename
line = linecache.getline(filename, lineno).strip()
res.append((filename, lineno, line))
opt_tb = opt_tb.tb_next
res.reverse()
return res
# pyre-fixme[24]: Generic type `unittest.case._AssertRaisesContext` expects 1 type
# parameter.
class _AssertRaisesContextOn(unittest.case._AssertRaisesContext):
"""
Attributes:
lineno: the line number on which the error occured
filename: the file in which the error occured
"""
_expected_line: Optional[str]
lineno: Optional[int]
filename: Optional[str]
def __init__(
self,
expected: Type[Exception],
test_case: unittest.TestCase,
expected_line: Optional[str] = None,
expected_regex: Optional[str] = None,
) -> None:
self._expected_line = (
expected_line.strip() if expected_line is not None else None
)
self.lineno = None
self.filename = None
# pyre-fixme[28]: Unexpected keyword argument `expected`.
super().__init__(
expected=expected, test_case=test_case, expected_regex=expected_regex
)
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
# inconsistently.
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[types.TracebackType],
) -> bool:
"""This is called when the context closes. If an exception was raised
`exc_type`, `exc_value` and `tb` will be set.
"""
if not super().__exit__(exc_type, exc_value, tb):
return False # reraise
# super().__exit__ will throw if exc_type is None
assert exc_type is not None
assert exc_value is not None
assert tb is not None
frames = _get_tb_lines(tb)
self.filename, self.lineno, _ = frames[0]
lines = [line for _, _, line in frames]
if self._expected_line is not None and self._expected_line not in lines:
# pyre-ignore [16]: ... has no attribute `_raiseFailure`.
self._raiseFailure(
f"{self._expected_line!r} was not found in the traceback: {lines!r}"
)
return True
# Instead of showing a warning (like in the standard library) we throw an error when
# deprecated functions are called.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def _deprecate(original_func: Callable) -> Callable:
def _deprecated_func(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
raise RuntimeError(
f"This function is deprecated please use {original_func.__name__} "
"instead."
)
return _deprecated_func
[docs]class TestCase(unittest.TestCase):
"""The base test case for Ax, contains various helper functions to write unittest.
"""
def __init__(self, methodName: str = "runTest") -> None:
super().__init__(methodName=methodName)
[docs] def assertEqual(
self,
first: Any, # pyre-ignore[2]
second: Any, # pyre-ignore[2]
msg: Optional[str] = None,
) -> None:
if isinstance(first, Base) and isinstance(second, Base):
self.assertAxBaseEqual(first=first, second=second, msg=msg)
else:
super().assertEqual(first=first, second=second, msg=msg)
def assertAxBaseEqual(
self, first: Base, second: Base, msg: Optional[str] = None
) -> None:
self.assertIsInstance(
first, Base, "First argument is not a subclass of Ax `Base`."
)
self.assertIsInstance(
second, Base, "Second argument is not a subclass of Ax `Base`."
)
if first != second:
_, unequal_val = object_attribute_dicts_find_unequal_fields(
one_dict=first.__dict__, other_dict=second.__dict__, fast_return=False
)
msg = msg or ""
msg += f"{first} (type {type(first)}) != {second} (type {type(second)})."
if unequal_val:
msg += "\n\nFields with different values:"
for field, (one_val, other_val) in unequal_val.items():
msg += f"\n* {field}: {one_val} (type {type(one_val)}) VS. "
msg += f"{other_val} (type {type(other_val)})"
raise self.failureException(msg)
[docs] def assertRaisesOn(
self,
exc: Type[Exception],
line: Optional[str] = None,
regex: Optional[str] = None,
) -> ContextManager[None]:
"""Assert that an exception is raised on a specific line.
"""
context = _AssertRaisesContextOn(exc, self, line, regex)
# pyre-ignore [16]: ... has no attribute `handle`.
return context.handle("assertRaisesOn", [], {})
[docs] @staticmethod
@contextlib.contextmanager
def silence_stderr() -> Generator[None, None, None]:
"""A context manager that silences stderr for part of a test.
If any exception passes through this context manager the stderr will be printed,
otherwise it will be discarded.
"""
new_err = io.StringIO()
old_err = sys.stderr
try:
sys.stderr = new_err
yield
except Exception:
print(new_err.getvalue(), file=old_err, flush=True)
raise
finally:
sys.stderr = old_err
# This list is taken from the python standard library
# pyre-fixme[4]: Attribute must be annotated.
# pyre-fixme[4]: Attribute must be annotated.
failUnlessEqual = assertEquals = _deprecate(unittest.TestCase.assertEqual)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-fixme[4]: Attribute must be annotated.
failIfEqual = assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-fixme[4]: Attribute must be annotated.
failUnlessAlmostEqual = assertAlmostEquals = _deprecate(
unittest.TestCase.assertAlmostEqual
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-fixme[4]: Attribute must be annotated.
failIfAlmostEqual = assertNotAlmostEquals = _deprecate(
unittest.TestCase.assertNotAlmostEqual
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-fixme[4]: Attribute must be annotated.
failUnless = assert_ = _deprecate(unittest.TestCase.assertTrue)
# pyre-fixme[4]: Attribute must be annotated.
failUnlessRaises = _deprecate(unittest.TestCase.assertRaises)
# pyre-fixme[4]: Attribute must be annotated.
failIf = _deprecate(unittest.TestCase.assertFalse)
# pyre-fixme[4]: Attribute must be annotated.
assertRaisesRegexp = _deprecate(unittest.TestCase.assertRaisesRegex)
# pyre-fixme[4]: Attribute must be annotated.
assertRegexpMatches = _deprecate(unittest.TestCase.assertRegex)
# pyre-fixme[4]: Attribute must be annotated.
assertNotRegexpMatches = _deprecate(unittest.TestCase.assertNotRegex)