Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reinforce testenv in ax test case #2530

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ax/storage/sqa_store/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
"_steps",
"analysis_scheduler",
"_nodes",
# ``status_quo_weight_override`` is a field on ``BatchTrial`` not in the
# "trial_v2" table
# TODO(T193258337)
"_status_quo_weight_override",
}
SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."

Expand Down
3 changes: 3 additions & 0 deletions ax/utils/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,6 @@ class Keys(str, Enum):

DEFAULT_WINSORIZATION_LIMITS_MINIMIZATION: Tuple[float, float] = (0.0, 0.2)
DEFAULT_WINSORIZATION_LIMITS_MAXIMIZATION: Tuple[float, float] = (0.2, 0.0)

TESTENV_ENV_KEY = "TESTENV"
TESTENV_ENV_VAL = "True"
8 changes: 8 additions & 0 deletions ax/utils/common/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io
import linecache
import logging
import os
import signal
import sys
import types
Expand All @@ -39,6 +40,7 @@
import numpy as np
from ax.exceptions.core import AxParameterWarning
from ax.utils.common.base import Base
from ax.utils.common.constants import TESTENV_ENV_KEY, TESTENV_ENV_VAL
from ax.utils.common.equality import object_attribute_dicts_find_unequal_fields
from ax.utils.common.logger import get_logger
from botorch.exceptions.warnings import InputDataWarning
Expand Down Expand Up @@ -301,6 +303,12 @@ def signal_handler(signum: int, frame: Optional[FrameType]) -> None:

super().__init__(methodName=methodName)
signal.signal(signal.SIGALRM, signal_handler)
# This is set to indicate we are running in a test environment. Code can check
# this to:
# * more strictly enforce SQL encoding
# (https://github.com/facebook/Ax/blob/main/ax/storage/sqa_store/save.py#L598)
# * avoid actions that will affect product environments
os.environ[TESTENV_ENV_KEY] = TESTENV_ENV_VAL

def setUp(self) -> None:
"""
Expand Down