diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 905f140bd1c..69667f341cb 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -852,15 +852,21 @@ def _make_evaluations_and_data( ) return evaluations, data + def _raise_cant_attach_if_completed(self) -> None: + """ + Helper method used by `validate_can_attach_data` to raise an error if + the user tries to attach data to a completed trial. Subclasses such as + `Trial` override this by suggesting a remediation. + """ + raise UnsupportedError( + f"Trial {self.index} already has status 'completed', so data cannot " + "be attached." + ) + def _validate_can_attach_data(self) -> None: """Determines whether a trial is in a state that can be attached data.""" if self.status.is_completed: - raise UnsupportedError( - f"Trial {self.index} has already been completed with data." - "To add more data to it (for example, for a different metric), " - "use `Trial.update_trial_data()` or " - "BatchTrial.update_batch_trial_data()." - ) + self._raise_cant_attach_if_completed() if self.status.is_abandoned or self.status.is_failed: raise UnsupportedError( f"Trial {self.index} has been marked {self.status.name}, so it " diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 6f5c105cfea..5d9e3385bc9 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -19,6 +19,7 @@ from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.parameter import FixedParameter, ParameterType from ax.core.search_space import SearchSpace +from ax.exceptions.core import UnsupportedError from ax.runners.synthetic import SyntheticRunner from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast @@ -47,6 +48,16 @@ def setUp(self) -> None: self.weights = weights[1:] self.batch.add_arms_and_weights(arms=self.arms, weights=self.weights) + def ftest__validate_can_attach_data(self) -> None: + self.batch.mark_running(no_runner_required=True) + self.batch.mark_completed() + + expected_msg = ( + "Trial 0 already has status 'completed', so data cannot " "be attached." + ) + with self.assertRaisesRegex(UnsupportedError, expected_msg): + self.batch._validate_can_attach_data() + def test_Eq(self) -> None: new_batch_trial = self.experiment.new_batch_trial() self.assertNotEqual(self.batch, new_batch_trial) diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 36aaf3bbd00..ab99593b86e 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -16,7 +16,7 @@ from ax.core.data import Data from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.runner import Runner -from ax.exceptions.core import UserInputError +from ax.exceptions.core import UnsupportedError, UserInputError from ax.runners.synthetic import SyntheticRunner from ax.utils.common.result import Ok from ax.utils.common.testutils import TestCase @@ -60,6 +60,18 @@ def setUp(self) -> None: def tearDown(self) -> None: self.mock_supports_trial_type.stop() + def ftest__validate_can_attach_data(self) -> None: + self.trial.mark_running(no_runner_required=True) + self.trial.mark_completed() + + expected_msg = ( + "Trial 0 has already been completed with data. To add more data to " + # "it (for example, for a different metric), use " + # "`Trial.update_trial_data()`." + ) + with self.assertRaisesRegex(UnsupportedError, expected_msg): + self.trial._validate_can_attach_data() + def test_eq(self) -> None: new_trial = self.experiment.new_trial() self.assertNotEqual(self.trial, new_trial) diff --git a/ax/core/trial.py b/ax/core/trial.py index 74a49bd9b10..29dc6e078ed 100644 --- a/ax/core/trial.py +++ b/ax/core/trial.py @@ -19,10 +19,12 @@ from ax.core.data import Data from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.types import TCandidateMetadata, TEvaluationOutcome +from ax.exceptions.core import UnsupportedError from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import _round_floats_for_logging, get_logger from ax.utils.common.typeutils import not_none + logger: Logger = get_logger(__name__) ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES: int = 6 @@ -351,3 +353,15 @@ def clone_to( new_trial.add_generator_run(self.generator_run.clone()) self._update_trial_attrs_on_clone(new_trial=new_trial) return new_trial + + def _raise_cant_attach_if_completed(self) -> None: + """ + Helper method used by `validate_can_attach_data` to raise an error if + the user tries to attach data to a completed trial. Subclasses such as + `Trial` override this by suggesting a remediation. + """ + raise UnsupportedError( + f"Trial {self.index} has already been completed with data. " + "To add more data to it (for example, for a different metric), " + f"use `{self.__class__.__name__}.update_trial_data()`." + )