diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 48973f09896..2aa9cad3ece 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -546,8 +546,7 @@ def get_next_trial( ) trial.mark_running(no_runner_required=True) self._save_or_update_trial_in_db_if_possible( - experiment=self.experiment, - trial=trial, + experiment=self.experiment, trial=trial ) # TODO[T79183560]: Ensure correct handling of generator run when using # foreign keys. @@ -1294,6 +1293,9 @@ def stop_trial_early(self, trial_index: int) -> None: trial = self.get_trial(trial_index) trial.mark_early_stopped() logger.info(f"Early stopped trial {trial_index}.") + self._save_or_update_trial_in_db_if_possible( + experiment=self.experiment, trial=trial + ) def estimate_early_stopping_savings(self, map_key: Optional[str] = None) -> float: """Estimate early stopping savings using progressions of the MapMetric present @@ -1625,8 +1627,7 @@ def _update_trial_with_raw_data( trial.mark_completed() self._save_or_update_trial_in_db_if_possible( - experiment=self.experiment, - trial=trial, + experiment=self.experiment, trial=trial ) return update_info diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index d8902c52703..d232efafed8 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -1963,6 +1963,18 @@ def test_sqa_storage(self) -> None: # Original experiment should still be in DB and not have been overwritten. self.assertEqual(len(ax_client.experiment.trials), 5) + # Attach an early stopped trial. + parameters, trial_index = ax_client.get_next_trial() + ax_client.stop_trial_early(trial_index=trial_index) + + # Reload experiment and check that trial status is accurate. + ax_client_new = AxClient(db_settings=db_settings) + ax_client_new.load_experiment_from_database("test_experiment") + self.assertEqual( + ax_client.experiment.trials_by_status, + ax_client_new.experiment.trials_by_status, + ) + def test_overwrite(self) -> None: init_test_engine_and_session_factory(force_init=True) ax_client = AxClient()