Source code for ax.utils.testing.unittest_conventions

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import pathlib
import sys
import unittest

import __test_modules__
from ax.utils.common import testutils


def get_all_subclasses(cls):
    """Get all the (direct and indirect) subclasses of cls.

    The hierarchy is walked depth-first in pre-order: each class is yielded
    before its own subclasses, and siblings come in the order reported by
    ``__subclasses__()``. A class reachable through several parents (diamond
    inheritance) is yielded only the first time it is encountered.

    Args:
        cls: The class whose subclass hierarchy should be walked. This may also
            be a metaclass (a subclass of ``type``).

    Yields:
        Every subclass of ``cls`` exactly once. ``cls`` itself is not included.
    """
    seen = set()
    # Call ``type.__subclasses__`` explicitly. ``cls.__subclasses__()`` raises
    # TypeError when ``cls`` is itself a metaclass, because the attribute then
    # resolves to the unbound descriptor defined on ``type``.
    # Use an explicit stack rather than recursion so that very deep hierarchies
    # cannot hit the recursion limit. Children are pushed in reverse so that
    # they are popped in their natural order.
    stack = list(reversed(type.__subclasses__(cls)))
    while stack:
        sub = stack.pop()
        if sub in seen:
            # Already reached via another parent (diamond inheritance).
            continue
        seen.add(sub)
        yield sub
        # __subclasses__ only contains direct descendants.
        stack.extend(reversed(type.__subclasses__(sub)))
class TestUnittestConventions(testutils.TestCase):
    """Enforce project-wide conventions on the test modules in ``__test_modules__``."""

    @staticmethod
    def _module_display_path(module_name, root):
        """Return a readable location for ``module_name`` to use in failure messages.

        Prefer the module's file path relative to ``root``. Fall back to the
        unmodified file path when the file is not under ``root``, and to the
        dotted module name when the module has no ``__file__``. Either way, a
        convention violation is reported as a clean test failure instead of an
        unrelated ``ValueError``/``TypeError`` raised while formatting it.
        """
        module_file = getattr(sys.modules[module_name], "__file__", None)
        if module_file is None:
            return module_name
        abs_path = pathlib.Path(module_file)
        try:
            return abs_path.relative_to(root)
        except ValueError:
            # The module lives outside the test root (e.g. a different install
            # location); show the path as-is.
            return abs_path

    def test_uses_ae_unittest(self):
        """Check that all of our tests are inheriting from our own base class

        Our base class does a bit more (like making sure we don't use any of
        python's deprecated `assert` functions) so we want to enforce its usage
        everywhere.
        """
        test_modules = set(__test_modules__.TEST_MODULES)
        # Import every test module so that all of the unittest.TestCase
        # subclasses they define are discoverable below.
        for m in test_modules:
            importlib.import_module(m)
        test_cases = [
            cls
            for cls in get_all_subclasses(unittest.TestCase)
            if cls.__module__ in test_modules
        ]
        base = testutils.TestCase
        # Loop-invariant: the directory that failure paths are reported relative to.
        root = pathlib.Path(__test_modules__.__file__).parent
        for t in test_cases:
            with self.subTest(t.__name__):
                if not issubclass(t, base):
                    filename = self._module_display_path(t.__module__, root)
                    self.fail(
                        f"in {filename}: {t.__qualname__} should inherit from "
                        f"{base.__module__}.{base.__name__}"
                    )