Skip to content

Commit

Permalink
plot top n features in contours (#2291)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2291

This diff updates report_utils to plot only the top features, ranked by feature importance, rather than skipping plotting altogether when the number of features grows too large.

* Adds `parameters_to_use` arg to `interact_contour` and `interact_contour_plotly`, which subsets which range parameters are plotted in contour plots and sets their order.
* Adds `importance` arg to  `_get_objective_v_param_plots` ([pointer](https://www.internalfb.com/diff/D50501707?permalink=371874009149429)) to subset the top n parameters by importance on a per-metric basis, then provide those as `parameters_to_use` when calling `interact_contour_ploty`.

Reviewed By: dme65

Differential Revision: D50501707

fbshipit-source-id: fde36494e38d19fc6f47fb9550f85740123de300
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Apr 2, 2024
1 parent 95bafb0 commit 2967164
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 46 deletions.
18 changes: 17 additions & 1 deletion ax/plot/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import re
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import plotly.graph_objs as go
Expand Down Expand Up @@ -343,6 +343,7 @@ def interact_contour_plotly(
lower_is_better: bool = False,
fixed_features: Optional[ObservationFeatures] = None,
trial_index: Optional[int] = None,
parameters_to_use: Optional[List[str]] = None,
) -> go.Figure:
"""Create interactive plot with predictions for a 2-d slice of the parameter
space.
Expand All @@ -362,6 +363,8 @@ def interact_contour_plotly(
fixed_features: An ObservationFeatures object containing the values of
features (including non-parameter features like context) to be set
in the slice.
parameters_to_use: List of parameters to use in the plot, in the order they
should appear. If None or empty list, use all parameters.
Returns:
go.Figure: interactive plot of objective vs. parameters
Expand All @@ -378,6 +381,15 @@ def interact_contour_plotly(
slice_values["TRIAL_PARAM"] = str(trial_index)

range_parameters = get_range_parameters(model, min_num_values=5)
if parameters_to_use is not None:
if len(parameters_to_use) <= 1:
raise ValueError(
"Contour plots require two or more parameters. "
f"Got {parameters_to_use=}."
)
# Subset range parameters and put them in the same order as parameters_to_use.
range_param_name_dict = {p.name: p for p in range_parameters}
range_parameters = [range_param_name_dict[pname] for pname in parameters_to_use]
plot_data, _, _ = get_plot_data(
model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
)
Expand Down Expand Up @@ -886,6 +898,7 @@ def interact_contour(
lower_is_better: bool = False,
fixed_features: Optional[ObservationFeatures] = None,
trial_index: Optional[int] = None,
parameters_to_use: Optional[List[str]] = None,
) -> AxPlotConfig:
"""Create interactive plot with predictions for a 2-d slice of the parameter
space.
Expand All @@ -905,6 +918,8 @@ def interact_contour(
fixed_features: An ObservationFeatures object containing the values of
features (including non-parameter features like context) to be set
in the slice.
parameters_to_use: List of parameters to use in the plot, in the order they
should appear. If None or empty list, use all parameters.
Returns:
AxPlotConfig: interactive plot of objective vs. parameters
Expand All @@ -920,6 +935,7 @@ def interact_contour(
lower_is_better=lower_is_better,
fixed_features=fixed_features,
trial_index=trial_index,
parameters_to_use=parameters_to_use,
),
plot_type=AxPlotTypes.GENERIC,
)
28 changes: 26 additions & 2 deletions ax/plot/tests/test_contours.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
plot_contour_plotly,
)
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_high_dimensional_branin_experiment,
)
from ax.utils.testing.mock import fast_botorch_optimize


Expand All @@ -43,7 +46,7 @@ def test_Contours(self) -> None:
self.assertIsInstance(plot, go.Figure)
plot = interact_contour(model, list(model.metric_names)[0])
self.assertIsInstance(plot, AxPlotConfig)
plot = plot = plot_contour(
plot = plot_contour(
model, model.parameters[0], model.parameters[1], list(model.metric_names)[0]
)
self.assertIsInstance(plot, AxPlotConfig)
Expand All @@ -57,3 +60,24 @@ def test_Contours(self) -> None:
for text in d["text"]:
for tt in tooltips:
self.assertTrue(tt in text)

exp = get_high_dimensional_branin_experiment(with_batch=True)
exp.trials[0].run()
model = Models.BOTORCH_MODULAR(
experiment=exp,
data=exp.fetch_data(),
)
with self.assertRaisesRegex(
ValueError, "Contour plots require two or more parameters"
):
interact_contour_plotly(
model, list(model.metric_names)[0], parameters_to_use=["foo"]
)
for i in [2, 3]:
parameters_to_use = model.parameters[:i]
plot = interact_contour_plotly(
model, list(model.metric_names)[0], parameters_to_use=parameters_to_use
)
# pyre-ignore[16]: `plotly.graph_objs.graph_objs.Figure`
# has no attribute `layout`.
self.assertEqual(len(plot.layout.updatemenus[0].buttons), i)
39 changes: 32 additions & 7 deletions ax/service/tests/test_report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def test_get_standard_plots_moo(self) -> None:
plots = get_standard_plots(
experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data())
)
self.assertEqual(len(log.output), 2)
self.assertEqual(len(log.output), 4)
self.assertIn(
"Pareto plotting not supported for experiments with relative objective "
"thresholds.",
Expand All @@ -433,6 +433,14 @@ def test_get_standard_plots_moo(self) -> None:
"Failed to compute global feature sensitivities:",
log.output[1],
)
self.assertIn(
"Created contour plots for metric branin_b and parameters ['x2', 'x1']",
log.output[2],
)
self.assertIn(
"Created contour plots for metric branin_a and parameters ['x2', 'x1']",
log.output[3],
)
self.assertEqual(len(plots), 6)

@fast_botorch_optimize
Expand Down Expand Up @@ -516,16 +524,33 @@ def test_skip_contour_high_dimensional(self) -> None:
experiment=exp,
data=exp.fetch_data(),
)
with self.assertLogs(logger="ax", level=INFO) as log:
_get_objective_v_param_plots(
experiment=exp, model=model, max_num_contour_plots=2
)
self.assertEqual(len(log.output), 1)
self.assertIn(
"Created contour plots for metric objective and parameters",
log.output[0],
)
with self.assertLogs(logger="ax", level=WARN) as log:
_get_objective_v_param_plots(experiment=exp, model=model)
_get_objective_v_param_plots(
experiment=exp, model=model, max_num_contour_plots=1
)
self.assertEqual(len(log.output), 1)
self.assertIn("Skipping creation of 2450 contour plots", log.output[0])
self.assertIn(
"Skipping creation of contour plots",
log.output[0],
)
with self.assertLogs(logger="ax", level=WARN) as log:
_get_objective_v_param_plots(
experiment=exp, model=model, max_num_slice_plots=10
experiment=exp,
model=model,
max_num_contour_plots=1,
max_num_slice_plots=10,
)
# Adds two more warnings.
self.assertEqual(len(log.output), 3)
self.assertIn("Skipping creation of 50 slice plots", log.output[1])
# Creates two warnings, one for slice plots and one for contour plots.
self.assertEqual(len(log.output), 2)

def test_get_metric_name_pairs(self) -> None:
exp = get_branin_experiment(with_trial=True)
Expand Down
117 changes: 83 additions & 34 deletions ax/service/utils/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ def _get_objective_trace_plot(
def _get_objective_v_param_plots(
experiment: Experiment,
model: ModelBridge,
max_num_slice_plots: int = 50,
importance: Optional[
Union[Dict[str, Dict[str, np.ndarray]], Dict[str, Dict[str, float]]]
] = None,
# Chosen to take ~1min on local benchmarks.
max_num_slice_plots: int = 200,
# Chosen to take ~2min on local benchmarks.
max_num_contour_plots: int = 20,
) -> List[go.Figure]:
search_space = experiment.search_space
Expand All @@ -175,8 +180,9 @@ def _get_objective_v_param_plots(
"`RangeParameter`. Returning an empty list."
)
return []
range_param_names = [param.name for param in range_params]
num_range_params = len(range_params)
num_metrics = len(experiment.metrics)
num_metrics = len(model.metric_names)
num_slice_plots = num_range_params * num_metrics
output_plots = []
if num_slice_plots <= max_num_slice_plots:
Expand All @@ -197,32 +203,56 @@ def _get_objective_v_param_plots(
warning_plot = _warn_and_create_warning_plot(warning_msg=warning_msg)
output_plots.append(warning_plot)

num_contour_plots = num_range_params * (num_range_params - 1) * num_metrics
if num_range_params > 1 and num_contour_plots <= max_num_contour_plots:
# contour plots
try:
with gpytorch.settings.max_eager_kernel_size(float("inf")):
output_plots += [
interact_contour_plotly(
model=not_none(model),
metric_name=metric_name,
)
for metric_name in model.metric_names
]
# `mean shape torch.Size` RunTimeErrors, pending resolution of
# https://github.com/cornellius-gp/gpytorch/issues/1853
except RuntimeError as e:
logger.warning(f"Contour plotting failed with error: {e}.")
elif num_contour_plots > max_num_contour_plots:
# contour plots
num_contour_per_metric = max_num_contour_plots // num_metrics
if num_contour_per_metric < 2:
warning_msg = (
f"Skipping creation of {num_contour_plots} contour plots since that "
f"exceeds <br>`max_num_contour_plots = {max_num_contour_plots}`."
"Skipping creation of contour plots since that requires <br>"
"`max_num_contour_plots >= 2 * num_metrics`. Got "
f"{max_num_contour_plots=} and {num_metrics=}."
"<br>Users can plot individual contour plots with the <br>python "
"function ax.plot.contour.plot_contour_plotly."
)
# TODO: return a warning here then convert to a plot/message/etc. downstream.
warning_plot = _warn_and_create_warning_plot(warning_msg=warning_msg)
output_plots.append(warning_plot)
elif num_range_params > 1:
# Using n params yields n * (n - 1) contour plots, so we use the number of
# params that yields the desired number of plots (solved using quadratic eqn)
num_params_per_metric = int(0.5 + (0.25 + num_contour_per_metric) ** 0.5)
try:
for metric_name in model.metric_names:
if importance is not None:
range_params_sens_for_metric = {
k: v
for k, v in importance[metric_name].items()
if k in range_param_names
}
# sort the params by their sensitivity
params_to_use = sorted(
range_params_sens_for_metric,
key=lambda x: range_params_sens_for_metric[x],
reverse=True,
)[:num_params_per_metric]
# if sens is not available, just use the first num_features_per_metric.
else:
params_to_use = range_param_names[:num_params_per_metric]
with gpytorch.settings.max_eager_kernel_size(float("inf")):
output_plots.append(
interact_contour_plotly(
model=not_none(model),
metric_name=metric_name,
parameters_to_use=params_to_use,
)
)
logger.info(
f"Created contour plots for metric {metric_name} and parameters "
f"{params_to_use}."
)
# `mean shape torch.Size` RunTimeErrors, pending resolution of
# https://github.com/cornellius-gp/gpytorch/issues/1853
except RuntimeError as e:
logger.warning(f"Contour plotting failed with error: {e}.")
return output_plots


Expand Down Expand Up @@ -394,12 +424,44 @@ def get_standard_plots(
except Exception as e:
logger.exception(f"Scatter plot failed with error: {e}")

# Compute feature importance ("sensitivity") to select most important
# features to plot.
sens = None
importance_measure = ""
if global_sensitivity_analysis and isinstance(model, TorchModelBridge):
try:
logger.debug("Starting global sensitivity analysis.")
sens = ax_parameter_sens(model, order="total")
importance_measure = (
'<a href="https://en.wikipedia.org/wiki/Variance-based_'
'sensitivity_analysis">Variance-based sensitivity analysis</a>'
)
logger.debug("Finished global sensitivity analysis.")
except Exception as e:
logger.info(f"Failed to compute global feature sensitivities: {e}")
if sens is None:
try:
sens = {
metric_name: model.feature_importances(metric_name)
for i, metric_name in enumerate(sorted(model.metric_names))
}
except Exception as e:
logger.info(f"Failed to compute feature importances: {e}")

try:
logger.debug("Starting objective vs. param plots.")
# importance is the absolute value of sensitivity.
importance = None
if sens is not None:
importance = {
k: {j: np.absolute(sens[k][j]) for j in sens[k].keys()}
for k in sens.keys()
}
output_plot_list.extend(
_get_objective_v_param_plots(
experiment=experiment,
model=model,
importance=importance,
)
)
logger.debug("Finished objective vs. param plots.")
Expand All @@ -414,19 +476,6 @@ def get_standard_plots(
logger.exception(f"Cross-validation plot failed with error: {e}")

# sensitivity plot
sens = None
importance_measure = ""
try:
if global_sensitivity_analysis and isinstance(model, TorchModelBridge):
logger.debug("Starting global sensitivity analysis.")
sens = ax_parameter_sens(model, order="total", signed=True)
importance_measure = (
'<a href="https://en.wikipedia.org/wiki/Variance-based_'
'sensitivity_analysis">Variance-based sensitivity analysis</a>'
)
logger.debug("Finished global sensitivity analysis.")
except Exception as e:
logger.info(f"Failed to compute global feature sensitivities: {e}")
try:
logger.debug("Starting feature importance plot.")
feature_importance_plot = plot_feature_importance_by_feature_plotly(
Expand Down
9 changes: 7 additions & 2 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def get_experiment_with_observations(
return exp


def get_high_dimensional_branin_experiment() -> Experiment:
def get_high_dimensional_branin_experiment(with_batch: bool = False) -> Experiment:
search_space = SearchSpace(
# pyre-fixme[6]: In call `SearchSpace.__init__`, for 1st parameter `parameters`
# expected `List[Parameter]` but got `List[RangeParameter]`.
Expand Down Expand Up @@ -803,12 +803,17 @@ def get_high_dimensional_branin_experiment() -> Experiment:
)
)

return Experiment(
exp = Experiment(
name="high_dimensional_branin_experiment",
search_space=search_space,
optimization_config=optimization_config,
runner=SyntheticRunner(),
)
if with_batch:
sobol_generator = get_sobol(search_space=exp.search_space)
sobol_run = sobol_generator.gen(n=15)
exp.new_batch_trial().add_generator_run(sobol_run)
return exp


##############################
Expand Down

0 comments on commit 2967164

Please sign in to comment.