diff --git a/ax/utils/common/tests/test_testutils.py b/ax/utils/common/tests/test_testutils.py index 4a2572b0af9..390b59d3e49 100644 --- a/ax/utils/common/tests/test_testutils.py +++ b/ax/utils/common/tests/test_testutils.py @@ -102,3 +102,14 @@ def test_fail_deprecated(self) -> None: self.assertEqual(1, 1) with self.assertRaises(RuntimeError): self.assertEquals(1, 1) + + def test_ax_long_test_decorator(self) -> None: + testReason: str = "testReason" + + @TestCase.ax_long_test(testReason) + def decorated_test() -> None: + self.assertEqual(testReason, self._long_test_active_reason) + + self.assertEqual(None, self._long_test_active_reason) + decorated_test() + self.assertEqual(None, self._long_test_active_reason) diff --git a/ax/utils/common/testutils.py b/ax/utils/common/testutils.py index 5739ee88076..8efde4b41db 100644 --- a/ax/utils/common/testutils.py +++ b/ax/utils/common/testutils.py @@ -279,9 +279,10 @@ def custom_import(name: str, *args: Any, **kwargs: Any) -> Any: class TestCase(fake_filesystem_unittest.TestCase): """The base Ax test case, contains various helper functions to write unittests.""" - MAX_TEST_SECONDS = 480 + MAX_TEST_SECONDS = 60 NUMBER_OF_PROFILER_LINES_TO_OUTPUT = 20 PROFILE_TESTS = False + _long_test_active_reason: Optional[str] = None def __init__(self, methodName: str = "runTest") -> None: def signal_handler(signum: int, frame: Optional[FrameType]) -> None: @@ -292,7 +293,20 @@ def signal_handler(signum: int, frame: Optional[FrameType]) -> None: message += ( " To see a profiler output, set `TestCase.PROFILE_TESTS` to `True`." ) - logger.warning(message) + + if self._long_test_active_reason is None: + message += ( + " To specify a reason for a long running test," + + " utilize the @ax_long_test decorator. If your test " + + "is long because it's doing modeling, please use the " + + "@fast_botorch_optimize decorator and see if that helps." + ) + raise TimeoutError(message) + else: + message += ( + " Reason for long running test: " + self._long_test_active_reason + ) + logger.warning(message) super().__init__(methodName=methodName) signal.signal(signal.SIGALRM, signal_handler) @@ -484,6 +498,13 @@ def _print_profiler_output(self) -> None: for line in output[-self.NUMBER_OF_PROFILER_LINES_TO_OUTPUT :]: print(line) + @classmethod + @contextlib.contextmanager + def ax_long_test(cls, reason: Optional[str]) -> Generator[None, None, None]: + cls._long_test_active_reason = reason + yield + cls._long_test_active_reason = None + # This list is taken from the python standard library # pyre-fixme[4]: Attribute must be annotated. failUnlessEqual = assertEquals = _deprecate(unittest.TestCase.assertEqual)