diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index ada394dc9fb..0a3e5a2dbeb 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -944,6 +944,12 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float: for trial in self.experiment.trials.values() if trial.status.is_completed ] + try: + # The transitions are only available with step based GS. + # TODO: Clean up once transitions are available for all GS. + model_transitions = self.generation_strategy.model_transitions + except UnsupportedError: + model_transitions = None return optimization_trace_single_method( y=( np.minimum.accumulate(best_objectives, axis=1) @@ -954,7 +960,7 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float: title="Model performance vs. # of iterations", ylabel=objective_name.capitalize(), hover_labels=hover_labels, - model_transitions=self.generation_strategy.model_transitions, + model_transitions=model_transitions, ) def get_contour_plot( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 7ed8639ef28..219f9b74db7 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -31,6 +31,7 @@ ) from ax.core.parameter_constraint import OrderConstraint from ax.core.search_space import HierarchicalSearchSpace +from ax.core.trial import Trial from ax.core.types import ( ComparisonOp, TEvaluationOutcome, @@ -48,7 +49,12 @@ from ax.exceptions.generation_strategy import MaxParallelismReachedException from ax.metrics.branin import branin from ax.modelbridge.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.modelbridge.generation_strategy import ( + GenerationNode, + GenerationStep, + GenerationStrategy, +) +from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.random import RandomModelBridge from ax.modelbridge.registry import Models @@ -2906,6 +2912,35 @@ def test_SingleTaskGP_log_unordered_categorical_parameters(self) -> None: self.assertTrue(found_no_log) + def test_with_node_based_gs(self) -> None: + sobol_gs = GenerationStrategy( + name="Sobol", + nodes=[ + GenerationNode( + node_name="Sobol", + model_specs=[ModelSpec(model_enum=Models.SOBOL)], + ) + ], + ) + ax_client = get_branin_optimization(generation_strategy=sobol_gs) + params, idx = ax_client.get_next_trial() + ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)}) + + self.assertEqual(ax_client.generation_strategy.name, "Sobol") + self.assertEqual( + checked_cast( + Trial, ax_client.experiment.trials[0] + )._generator_run._model_key, + "Sobol", + ) + with mock.patch( + "ax.service.ax_client.optimization_trace_single_method" + ) as mock_plot: + ax_client.get_optimization_trace() + mock_plot.assert_called_once() + call_kwargs = mock_plot.call_args.kwargs + self.assertIsNone(call_kwargs["model_transitions"]) + # Utility functions for testing get_model_predictions without calling # get_next_trial. Create Ax Client with an experiment where