Skip to content

Commit

Permalink
Unblock node based GS in AxClient.get_optimization_trace (#2283)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2283

This would previously raise an `UnsupportedError` due to `model_transitions` being avaible only with step based GS.

Reviewed By: Balandat

Differential Revision: D54974577

fbshipit-source-id: 0118e0b5873b0e2ac5db29090e4816ce60c655b3
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 16, 2024
1 parent 426a66b commit 3ed20ea
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,12 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float:
for trial in self.experiment.trials.values()
if trial.status.is_completed
]
try:
# The transitions are only available with step based GS.
# TODO: Clean up once transitions are available for all GS.
model_transitions = self.generation_strategy.model_transitions
except UnsupportedError:
model_transitions = None
return optimization_trace_single_method(
y=(
np.minimum.accumulate(best_objectives, axis=1)
Expand All @@ -954,7 +960,7 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float:
title="Model performance vs. # of iterations",
ylabel=objective_name.capitalize(),
hover_labels=hover_labels,
model_transitions=self.generation_strategy.model_transitions,
model_transitions=model_transitions,
)

def get_contour_plot(
Expand Down
37 changes: 36 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from ax.core.parameter_constraint import OrderConstraint
from ax.core.search_space import HierarchicalSearchSpace
from ax.core.trial import Trial
from ax.core.types import (
ComparisonOp,
TEvaluationOutcome,
Expand All @@ -48,7 +49,12 @@
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.metrics.branin import branin
from ax.modelbridge.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.generation_strategy import (
GenerationNode,
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Models

Expand Down Expand Up @@ -2906,6 +2912,35 @@ def test_SingleTaskGP_log_unordered_categorical_parameters(self) -> None:

self.assertTrue(found_no_log)

def test_with_node_based_gs(self) -> None:
sobol_gs = GenerationStrategy(
name="Sobol",
nodes=[
GenerationNode(
node_name="Sobol",
model_specs=[ModelSpec(model_enum=Models.SOBOL)],
)
],
)
ax_client = get_branin_optimization(generation_strategy=sobol_gs)
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)})

self.assertEqual(ax_client.generation_strategy.name, "Sobol")
self.assertEqual(
checked_cast(
Trial, ax_client.experiment.trials[0]
)._generator_run._model_key,
"Sobol",
)
with mock.patch(
"ax.service.ax_client.optimization_trace_single_method"
) as mock_plot:
ax_client.get_optimization_trace()
mock_plot.assert_called_once()
call_kwargs = mock_plot.call_args.kwargs
self.assertIsNone(call_kwargs["model_transitions"])


# Utility functions for testing get_model_predictions without calling
# get_next_trial. Create Ax Client with an experiment where
Expand Down

0 comments on commit 3ed20ea

Please sign in to comment.