From ea37e6427422b0f2e8f7d5ca857f3068605a7444 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Tue, 5 Mar 2024 12:36:01 -0800 Subject: [PATCH] copying "cross_validation_helper" code from ax.plot (#2249) Summary: To prepare of creating the "CrossValidationPlot" module, this change imports the dependent code from "ax.plots". It also cleans up the original code by breaking it apart into different helper files, and trimming out methods which are not used to create this plot. Some new unit tests are being added as well. Differential Revision: D54495372 --- ax/analysis/helpers/color_helpers.py | 33 ++ ax/analysis/helpers/constants.py | 97 ++++++ .../helpers/cross_validation_helpers.py | 300 ++++++++++++++++ ax/analysis/helpers/layout_helpers.py | 96 ++++++ ax/analysis/helpers/plot_helpers.py | 78 +++++ ax/analysis/helpers/scatter_helpers.py | 325 ++++++++++++++++++ .../tests/test_cross_validation_helpers.py | 121 +++++++ .../helpers/tests/test_scatter_helpers.py | 42 +++ 8 files changed, 1092 insertions(+) create mode 100644 ax/analysis/helpers/color_helpers.py create mode 100644 ax/analysis/helpers/constants.py create mode 100644 ax/analysis/helpers/cross_validation_helpers.py create mode 100644 ax/analysis/helpers/layout_helpers.py create mode 100644 ax/analysis/helpers/plot_helpers.py create mode 100644 ax/analysis/helpers/scatter_helpers.py create mode 100644 ax/analysis/helpers/tests/test_cross_validation_helpers.py create mode 100644 ax/analysis/helpers/tests/test_scatter_helpers.py diff --git a/ax/analysis/helpers/color_helpers.py b/ax/analysis/helpers/color_helpers.py new file mode 100644 index 00000000000..575b641f014 --- /dev/null +++ b/ax/analysis/helpers/color_helpers.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from numbers import Real +from typing import List, Tuple + +# type aliases +TRGB = Tuple[Real, ...] + + +def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str: + """Convert RGB tuple to an RGBA string.""" + return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha) + + +def plotly_color_scale( + list_of_rgb_tuples: List[TRGB], + reverse: bool = False, + alpha: float = 1, +) -> List[Tuple[float, str]]: + """Convert list of RGB tuples to list of tuples, where each tuple is + break in [0, 1] and stringified RGBA color. + """ + if reverse: + list_of_rgb_tuples = list_of_rgb_tuples[::-1] + return [ + (round(i / (len(list_of_rgb_tuples) - 1), 3), rgba(rgb)) + for i, rgb in enumerate(list_of_rgb_tuples) + ] diff --git a/ax/analysis/helpers/constants.py b/ax/analysis/helpers/constants.py new file mode 100644 index 00000000000..c5af778ebac --- /dev/null +++ b/ax/analysis/helpers/constants.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import enum + +# Constants used for numerous plots +CI_OPACITY = 0.4 +DECIMALS = 3 +Z = 1.96 + + +# color constants used for plotting +class COLORS(enum.Enum): + STEELBLUE = (128, 177, 211) + CORAL = (251, 128, 114) + TEAL = (141, 211, 199) + PINK = (188, 128, 189) + LIGHT_PURPLE = (190, 186, 218) + ORANGE = (253, 180, 98) + + +# colors to be used for plotting discrete series +# pyre-fixme[5]: Global expression must be annotated. +DISCRETE_COLOR_SCALE = [ + COLORS.STEELBLUE.value, + COLORS.CORAL.value, + COLORS.PINK.value, + COLORS.LIGHT_PURPLE.value, + COLORS.ORANGE.value, + COLORS.TEAL.value, +] + +# 11-class PiYG from ColorBrewer (for contour plots) +GREEN_PINK_SCALE = [ + (142, 1, 82), + (197, 27, 125), + (222, 119, 174), + (241, 182, 218), + (253, 224, 239), + (247, 247, 247), + (230, 245, 208), + (184, 225, 134), + (127, 188, 65), + (77, 146, 33), + (39, 100, 25), +] +GREEN_SCALE = [ + (247, 252, 253), + (229, 245, 249), + (204, 236, 230), + (153, 216, 201), + (102, 194, 164), + (65, 174, 118), + (35, 139, 69), + (0, 109, 44), + (0, 68, 27), +] +BLUE_SCALE = [ + (255, 247, 251), + (236, 231, 242), + (208, 209, 230), + (166, 189, 219), + (116, 169, 207), + (54, 144, 192), + (5, 112, 176), + (3, 78, 123), +] +# 24 Class Mixed Color Palette +# Source: https://graphicdesign.stackexchange.com/a/3815 +MIXED_SCALE = [ + (2, 63, 165), + (125, 135, 185), + (190, 193, 212), + (214, 188, 192), + (187, 119, 132), + (142, 6, 59), + (74, 111, 227), + (133, 149, 225), + (181, 187, 227), + (230, 175, 185), + (224, 123, 145), + (211, 63, 106), + (17, 198, 56), + (141, 213, 147), + (198, 222, 199), + (234, 211, 198), + (240, 185, 141), + (239, 151, 8), + (15, 207, 192), + (156, 222, 214), + (213, 234, 231), + (243, 225, 235), + (246, 196, 225), + (247, 156, 212), +] diff --git a/ax/analysis/helpers/cross_validation_helpers.py b/ax/analysis/helpers/cross_validation_helpers.py new file mode 100644 index 00000000000..b6fac5b73fa --- /dev/null +++ b/ax/analysis/helpers/cross_validation_helpers.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import plotly.graph_objs as go + +from ax.analysis.helpers.constants import Z + +from ax.analysis.helpers.layout_helpers import layout_format, updatemenus_format +from ax.analysis.helpers.scatter_helpers import ( + _error_scatter_data, + _error_scatter_trace, + PlotData, + PlotInSampleArm, + PlotMetric, +) + +from ax.exceptions.core import UnsupportedPlotError + +from ax.modelbridge.cross_validation import CVResult + + +# Helper functions for plotting model fits +def get_min_max_with_errors( + x: List[float], y: List[float], sd_x: List[float], sd_y: List[float] +) -> Tuple[float, float]: + """Get min and max of a bivariate dataset (across variables). + + Args: + x: point estimate of x variable. + y: point estimate of y variable. + sd_x: standard deviation of x variable. + sd_y: standard deviation of y variable. + + Returns: + min_: minimum of points, including uncertainty. + max_: maximum of points, including uncertainty. + + """ + min_ = min( + min(np.array(x) - np.multiply(sd_x, Z)), min(np.array(y) - np.multiply(sd_y, Z)) + ) + max_ = max( + max(np.array(x) + np.multiply(sd_x, Z)), max(np.array(y) + np.multiply(sd_y, Z)) + ) + return min_, max_ + + +def diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str, Any]: + """Diagonal line trace from (min_, min_) to (max_, max_). + + Args: + min_: minimum to be used for starting point of line. + max_: maximum to be used for ending point of line. + visible: if True, trace is set to visible. + + """ + return go.Scatter( + x=[min_, max_], + y=[min_, max_], + line=dict(color="black", width=2, dash="dot"), # noqa: C408 + mode="lines", + hoverinfo="none", + visible=visible, + showlegend=False, + ) + + +def default_value_se_raw(se_raw: Optional[List[float]], out_length: int) -> List[float]: + """ + Takes a list of standard errors and maps edge cases to default list + of floats. + + """ + new_se_raw = ( + [0.0 if np.isnan(se) else se for se in se_raw] + if se_raw is not None + else [0.0] * out_length + ) + return new_se_raw + + +def obs_vs_pred_dropdown_plot( + data: PlotData, + rel: bool, + show_context: bool = False, + xlabel: str = "Actual Outcome", + ylabel: str = "Predicted Outcome", + autoset_axis_limits: bool = True, +) -> go.Figure: + """Plot a dropdown plot of observed vs. predicted values from a model. + NOTE: If relative, cannot show additional context + on arm details on hover + + Args: + data: a name tuple storing observed and predicted data + from a model. + rel: if True, plot metrics relative to the status quo. + show_context: Show arm detail context on hover. + xlabel: Label for x-axis. + ylabel: Label for y-axis. + autoset_axis_limits: Automatically try to set the limit for each axis to focus + on the region of interest. + """ + traces = [] + metric_dropdown = [] + layout_axis_range = [] + status_quo_arm = None + if rel and data.status_quo_name is not None: + if show_context: + raise UnsupportedPlotError( + "This plot does not support both context and relativization at " + "the same time." + ) + status_quo_arm = data.in_sample[data.status_quo_name] + + for i, metric in enumerate(data.metrics): + y_raw, se_raw, y_hat, se_hat = _error_scatter_data( + list(data.in_sample.values()), + y_axis_var=PlotMetric(metric_name=metric, pred=True, rel=rel), + x_axis_var=PlotMetric(metric_name=metric, pred=False, rel=rel), + status_quo_arm=status_quo_arm, + ) + se_raw = default_value_se_raw(se_raw=se_raw, out_length=len(y_raw)) + + min_, max_ = get_min_max_with_errors(x=y_raw, y=y_hat, sd_x=se_raw, sd_y=se_hat) + if autoset_axis_limits: + y_raw_np = np.array(y_raw) + # TODO: replace interpolation->method once it becomes standard. + q1 = np.nanpercentile(y_raw_np, q=25, interpolation="lower").min() + q3 = np.nanpercentile(y_raw_np, q=75, interpolation="higher").max() + y_lower = q1 - 1.5 * (q3 - q1) + y_upper = q3 + 1.5 * (q3 - q1) + y_raw_np = y_raw_np.clip(y_lower, y_upper).tolist() + min_robust, max_robust = get_min_max_with_errors( + x=y_raw_np, y=y_hat, sd_x=se_raw, sd_y=se_hat + ) + y_padding = 0.05 * (max_robust - min_robust) + # Use the min/max of the limits + layout_axis_range.append( + [max(min_robust, min_) - y_padding, min(max_robust, max_) + y_padding] + ) + traces.append( + diagonal_trace( + min(min_robust, min_) - y_padding, + max(max_robust, max_) + y_padding, + visible=(i == 0), + ) + ) + else: + layout_axis_range.append(None) + traces.append(diagonal_trace(min_, max_, visible=(i == 0))) + + traces.append( + _error_scatter_trace( + arms=list(data.in_sample.values()), + hoverinfo="text", + show_arm_details_on_hover=True, + show_CI=True, + show_context=show_context, + status_quo_arm=status_quo_arm, + visible=(i == 0), + x_axis_label=xlabel, + x_axis_var=PlotMetric(metric_name=metric, pred=False, rel=rel), + y_axis_label=ylabel, + y_axis_var=PlotMetric(metric_name=metric, pred=True, rel=rel), + ) + ) + + # only the first two traces are visible (corresponding to first outcome + # in dropdown) + is_visible = [False] * (len(data.metrics) * 2) + is_visible[2 * i] = True + is_visible[2 * i + 1] = True + + # on dropdown change, restyle + metric_dropdown.append( + { + "args": [ + {"visible": is_visible}, + { + "xaxis.range": layout_axis_range[-1], + "yaxis.range": layout_axis_range[-1], + }, + ], + "label": metric, + "method": "update", + } + ) + + updatemenus = updatemenus_format(metric_dropdown=metric_dropdown) + layout = layout_format( + layout_axis_range_value=layout_axis_range[0], + xlabel=xlabel, + ylabel=ylabel, + updatemenus=updatemenus, + ) + + return go.Figure(data=traces, layout=layout) + + +def get_cv_plot_data( + cv_results: List[CVResult], label_dict: Optional[Dict[str, str]] +) -> PlotData: + """Construct PlotData from cv_results, mapping observed to y and se, + and predicted to y_hat and se_hat. + + Args: + cv_results: A CVResult for each observation in the training data. + label_dict: optional map from real metric names to shortened names + + Returns: + PlotData with the following fields: + metrics: List[str] + in_sample: Dict[str, PlotInSampleArm] + PlotInSample arm have the fields + { + "name" + "y" + "se" + "parameters" + "y_hat" + "se_hat" + "context_stratum": (Always NONE) + } + out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]] + (NOTE: Always None) + status_quo_name: Optional[str] + (NOTE: Always None) + + """ + if len(cv_results) == 0: + return PlotData( + metrics=[], in_sample={}, out_of_sample=None, status_quo_name=None + ) + + if label_dict is None: + label_dict = {} + # Apply label_dict to cv_results + cv_results = deepcopy(cv_results) # Copy and edit in-place + for cv_i in cv_results: + cv_i.observed.data.metric_names = [ + label_dict.get(m, m) for m in cv_i.observed.data.metric_names + ] + cv_i.predicted.metric_names = [ + label_dict.get(m, m) for m in cv_i.predicted.metric_names + ] + + # arm_name -> Arm data + insample_data: Dict[str, PlotInSampleArm] = {} + + # Get the union of all metric_names seen in predictions + metric_names = list( + set().union(*(cv_result.predicted.metric_names for cv_result in cv_results)) + ) + + for rid, cv_result in enumerate(cv_results): + arm_name = cv_result.observed.arm_name + arm_data = { + "name": cv_result.observed.arm_name, + "y": {}, + "se": {}, + "parameters": cv_result.observed.features.parameters, + "y_hat": {}, + "se_hat": {}, + "context_stratum": None, + } + for i, mname in enumerate(cv_result.observed.data.metric_names): + # pyre-fixme[16]: Optional type has no attribute `__setitem__`. + arm_data["y"][mname] = cv_result.observed.data.means[i] + # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any, + # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]` + # has no attribute `__setitem__`. + arm_data["se"][mname] = np.sqrt(cv_result.observed.data.covariance[i][i]) + for i, mname in enumerate(cv_result.predicted.metric_names): + # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any, + # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]` + # has no attribute `__setitem__`. + arm_data["y_hat"][mname] = cv_result.predicted.means[i] + # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any, + # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]` + # has no attribute `__setitem__`. + arm_data["se_hat"][mname] = np.sqrt(cv_result.predicted.covariance[i][i]) + + # Expected `str` for 2nd anonymous parameter to call `dict.__setitem__` but got + # `Optional[str]`. + # pyre-fixme[6]: + insample_data[f"{arm_name}_{rid}"] = PlotInSampleArm(**arm_data) + return PlotData( + metrics=metric_names, + in_sample=insample_data, + out_of_sample=None, + status_quo_name=None, + ) diff --git a/ax/analysis/helpers/layout_helpers.py b/ax/analysis/helpers/layout_helpers.py new file mode 100644 index 00000000000..e77903cab35 --- /dev/null +++ b/ax/analysis/helpers/layout_helpers.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple, Type + +import plotly.graph_objs as go + + +def updatemenus_format(metric_dropdown: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [ + { + "x": 0, + "y": 1.125, + "yanchor": "top", + "xanchor": "left", + "buttons": metric_dropdown, + }, + { + "buttons": [ + { + "args": [ + { + "error_x.width": 4, + "error_x.thickness": 2, + "error_y.width": 4, + "error_y.thickness": 2, + } + ], + "label": "Yes", + "method": "restyle", + }, + { + "args": [ + { + "error_x.width": 0, + "error_x.thickness": 0, + "error_y.width": 0, + "error_y.thickness": 0, + } + ], + "label": "No", + "method": "restyle", + }, + ], + "x": 1.125, + "xanchor": "left", + "y": 0.8, + "yanchor": "middle", + }, + ] + + +def layout_format( + layout_axis_range_value: Tuple[float, float], + xlabel: str, + ylabel: str, + updatemenus: List[Dict[str, Any]], +) -> Type[go.Figure]: + layout = go.Layout( + annotations=[ + { + "showarrow": False, + "text": "Show CI", + "x": 1.125, + "xanchor": "left", + "xref": "paper", + "y": 0.9, + "yanchor": "middle", + "yref": "paper", + } + ], + xaxis={ + "range": layout_axis_range_value, + "title": xlabel, + "zeroline": False, + "mirror": True, + "linecolor": "black", + "linewidth": 0.5, + }, + yaxis={ + "range": layout_axis_range_value, + "title": ylabel, + "zeroline": False, + "mirror": True, + "linecolor": "black", + "linewidth": 0.5, + }, + showlegend=False, + hovermode="closest", + updatemenus=updatemenus, + width=530, + height=500, + ) + return layout diff --git a/ax/analysis/helpers/plot_helpers.py b/ax/analysis/helpers/plot_helpers.py new file mode 100644 index 00000000000..60aab0373c2 --- /dev/null +++ b/ax/analysis/helpers/plot_helpers.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from logging import Logger +from typing import Dict, List, Optional, Tuple, Union + +from ax.analysis.helpers.constants import DECIMALS, Z + +from ax.core.generator_run import GeneratorRun + +from ax.core.types import TParameterization +from ax.utils.common.logger import get_logger + + +logger: Logger = get_logger(__name__) + +# Typing alias +RawData = List[Dict[str, Union[str, float]]] + +TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]] + + +def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str: + """Format a dictionary for labels. + + Args: + param_dict: Dictionary to be formatted + name: String name of the thing being formatted. + + Returns: stringified blob. + """ + if len(param_dict) >= 10: + blob = "{} has too many items to render on hover ({}).".format( + name, len(param_dict) + ) + else: + blob = "
{}:
{}".format( + name, "
".join("{}: {}".format(n, v) for n, v in param_dict.items()) + ) + return blob + + +def _format_CI(estimate: float, sd: float, relative: bool, zval: float = Z) -> str: + """Format confidence intervals given estimate and standard deviation. + + Args: + estimate: point estimate. + sd: standard deviation of point estimate. + relative: if True, '%' is appended. + zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs) + + Returns: formatted confidence interval. + """ + return "[{lb:.{digits}f}{perc}, {ub:.{digits}f}{perc}]".format( + lb=estimate - zval * sd, + ub=estimate + zval * sd, + digits=DECIMALS, + perc="%" if relative else "", + ) + + +def arm_name_to_sort_key(arm_name: str) -> Tuple[str, int, int]: + """Parses arm name into tuple suitable for reverse sorting by key + + Example: + arm_names = ["0_0", "1_10", "1_2", "10_0", "control"] + sorted(arm_names, key=arm_name_to_sort_key, reverse=True) + ["control", "0_0", "1_2", "1_10", "10_0"] + """ + try: + trial_index, arm_index = arm_name.split("_") + return ("", -int(trial_index), -int(arm_index)) + except (ValueError, IndexError): + return (arm_name, 0, 0) diff --git a/ax/analysis/helpers/scatter_helpers.py b/ax/analysis/helpers/scatter_helpers.py new file mode 100644 index 00000000000..62c9431d4fa --- /dev/null +++ b/ax/analysis/helpers/scatter_helpers.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numbers + +from typing import ( + Any, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) + +import numpy as np +import plotly.graph_objs as go +from ax.analysis.helpers.color_helpers import rgba + +from ax.analysis.helpers.constants import BLUE_SCALE, CI_OPACITY, COLORS, DECIMALS, Z + +from ax.analysis.helpers.plot_helpers import ( + _format_CI, + _format_dict, + arm_name_to_sort_key, +) + +from ax.core.types import TParameterization + +from ax.exceptions.core import UnsupportedPlotError + +from ax.utils.stats.statstools import relativize + + +# Structs for plot data +class PlotMetric(NamedTuple): + """Struct for metric""" + + metric_name: str + pred: bool + rel: bool + + +class PlotInSampleArm(NamedTuple): + """Struct for in-sample arms (both observed and predicted data)""" + + name: str + parameters: TParameterization + y: Dict[str, float] + y_hat: Dict[str, float] + se: Dict[str, float] + se_hat: Dict[str, float] + context_stratum: Optional[Dict[str, Union[str, float]]] + + +class PlotOutOfSampleArm(NamedTuple): + """Struct for out-of-sample arms (only predicted data)""" + + name: str + parameters: TParameterization + y_hat: Dict[str, float] + se_hat: Dict[str, float] + context_stratum: Optional[Dict[str, Union[str, float]]] + + +class PlotData(NamedTuple): + """Struct for plot data, including both in-sample and out-of-sample arms""" + + metrics: List[str] + in_sample: Dict[str, PlotInSampleArm] + out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]] + status_quo_name: Optional[str] + + +def _error_scatter_data( + arms: Iterable[Union[PlotInSampleArm, PlotOutOfSampleArm]], + y_axis_var: PlotMetric, + x_axis_var: Optional[PlotMetric] = None, + status_quo_arm: Optional[PlotInSampleArm] = None, +) -> Tuple[List[float], Optional[List[float]], List[float], List[float]]: + y_metric_key = "y_hat" if y_axis_var.pred else "y" + y_sd_key = "se_hat" if y_axis_var.pred else "se" + + arm_names = [a.name for a in arms] + y = [getattr(a, y_metric_key).get(y_axis_var.metric_name, np.nan) for a in arms] + y_se = [getattr(a, y_sd_key).get(y_axis_var.metric_name, np.nan) for a in arms] + + # Delta method if relative to status quo arm + if y_axis_var.rel: + if status_quo_arm is None: + raise UnsupportedPlotError( + "`status_quo_arm` cannot be None for relative effects." + ) + y_rel, y_se_rel = relativize( + means_t=y, + sems_t=y_se, + mean_c=getattr(status_quo_arm, y_metric_key).get(y_axis_var.metric_name), + sem_c=getattr(status_quo_arm, y_sd_key).get(y_axis_var.metric_name), + as_percent=True, + ) + y = y_rel.tolist() + y_se = y_se_rel.tolist() + + # x can be metric for a metric or arm names + if x_axis_var is None: + x = arm_names + x_se = None + else: + x_metric_key = "y_hat" if x_axis_var.pred else "y" + x_sd_key = "se_hat" if x_axis_var.pred else "se" + x = [getattr(a, x_metric_key).get(x_axis_var.metric_name, np.nan) for a in arms] + x_se = [getattr(a, x_sd_key).get(x_axis_var.metric_name, np.nan) for a in arms] + + if x_axis_var.rel: + # Delta method if relative to status quo arm + x_rel, x_se_rel = relativize( + means_t=x, + sems_t=x_se, + mean_c=getattr(status_quo_arm, x_metric_key).get( + x_axis_var.metric_name + ), + sem_c=getattr(status_quo_arm, x_sd_key).get(x_axis_var.metric_name), + as_percent=True, + ) + x = x_rel.tolist() + x_se = x_se_rel.tolist() + return x, x_se, y, y_se + + +def _error_scatter_trace( + arms: Sequence[Union[PlotInSampleArm, PlotOutOfSampleArm]], + y_axis_var: PlotMetric, + x_axis_var: Optional[PlotMetric] = None, + y_axis_label: Optional[str] = None, + x_axis_label: Optional[str] = None, + status_quo_arm: Optional[PlotInSampleArm] = None, + show_CI: bool = True, + name: str = "In-sample", + color: Tuple[int] = COLORS.STEELBLUE.value, + visible: bool = True, + legendgroup: Optional[str] = None, + showlegend: bool = True, + hoverinfo: str = "text", + show_arm_details_on_hover: bool = True, + show_context: bool = False, + arm_noun: str = "arm", + color_parameter: Optional[str] = None, + color_metric: Optional[str] = None, +) -> Dict[str, Any]: + """Plot scatterplot with error bars. + + Args: + arms (List[Union[PlotInSampleArm, PlotOutOfSampleArm]]): + a list of in-sample or out-of-sample arms. + In-sample arms have observed data, while out-of-sample arms + just have predicted data. As a result, + when passing out-of-sample arms, pred must be True. + y_axis_var: name of metric for y-axis, along with whether + it is observed or predicted. + x_axis_var: name of metric for x-axis, + along with whether it is observed or predicted. If None, arm names + are automatically used. + y_axis_label: custom label to use for y axis. + If None, use metric name from `y_axis_var`. + x_axis_label: custom label to use for x axis. + If None, use metric name from `x_axis_var` if that is not None. + status_quo_arm: the status quo + arm. Necessary for relative metrics. + show_CI: if True, plot confidence intervals. + name: name of trace. Default is "In-sample". + color: color as rgb tuple. Default is + (128, 177, 211), which corresponds to COLORS.STEELBLUE. + visible: if True, trace is visible (default). + legendgroup: group for legends. + showlegend: if True, legend if rendered. + hoverinfo: information to show on hover. Default is + custom text. + show_arm_details_on_hover: if True, display + parameterizations of arms on hover. Default is True. + show_context: if True and show_arm_details_on_hover, + context will be included in the hover. + arm_noun: noun to use instead of "arm" (e.g. group) + color_parameter: color points according to the specified parameter, + cannot be used together with color_metric. + color_metric: color points according to the specified metric, + cannot be used together with color_parameter. + """ + if color_metric and color_parameter: + raise RuntimeError( + "color_metric and color_parameter cannot be used at the same time!" + ) + + if (color_metric or color_parameter) and not all( + isinstance(arm, PlotInSampleArm) for arm in arms + ): + raise RuntimeError("Color coding currently only works with in-sample arms!") + + # Opportunistically sort if arm names are in {trial}_{arm} format + arms = sorted(arms, key=lambda a: arm_name_to_sort_key(a.name), reverse=True) + + x, x_se, y, y_se = _error_scatter_data( + arms=arms, + y_axis_var=y_axis_var, + x_axis_var=x_axis_var, + status_quo_arm=status_quo_arm, + ) + labels = [] + colors = [] + + arm_names = [a.name for a in arms] + + # No relativization if no x variable. + rel_x = x_axis_var.rel if x_axis_var else False + rel_y = y_axis_var.rel + + for i in range(len(arm_names)): + heading = f"{arm_noun.title()} {arm_names[i]}
" + x_lab = ( + "{name}: {estimate}{perc} {ci}
".format( + name=x_axis_var.metric_name if x_axis_label is None else x_axis_label, + estimate=( + round(x[i], DECIMALS) if isinstance(x[i], numbers.Number) else x[i] + ), + ci="" if x_se is None else _format_CI(x[i], x_se[i], rel_x), + perc="%" if rel_x else "", + ) + if x_axis_var is not None + else "" + ) + y_lab = "{name}: {estimate}{perc} {ci}
".format( + name=y_axis_var.metric_name if y_axis_label is None else y_axis_label, + estimate=( + round(y[i], DECIMALS) if isinstance(y[i], numbers.Number) else y[i] + ), + ci="" if y_se is None else _format_CI(y[i], y_se[i], rel_y), + perc="%" if rel_y else "", + ) + + parameterization = ( + _format_dict(arms[i].parameters, "Parameterization") + if show_arm_details_on_hover + else "" + ) + + if color_parameter: + colors.append(arms[i].parameters[color_parameter]) + elif color_metric: + arm = arms[i] + if isinstance(arm, PlotInSampleArm): + colors.append(arm.y[color_metric]) + + arm = arms[i] + context = ( + # Expected `Dict[str, Optional[Union[bool, float, str]]]` for 1st anonymous + # parameter to call `ax.analysis.helpers.plot_helpers._format_dict` but got + # `Optional[Dict[str, Union[float, str]]]`. + # pyre-fixme[6]: + _format_dict(arms[i].context_stratum, "Context") + if show_arm_details_on_hover + and show_context # noqa W503 + and arms[i].context_stratum # noqa W503 + else "" + ) + + labels.append( + "{arm_name}
{xlab}{ylab}{param_blob}{context}".format( + arm_name=heading, + xlab=x_lab, + ylab=y_lab, + param_blob=parameterization, + context=context, + ) + ) + i += 1 + + if color_metric or color_parameter: + rgba_blue_scale = [rgba(c) for c in BLUE_SCALE] + marker = { + "color": colors, + "colorscale": rgba_blue_scale, + "colorbar": {"title": color_metric or color_parameter}, + "showscale": True, + } + else: + marker = {"color": rgba(color)} + + trace = go.Scatter( + x=x, + y=y, + marker=marker, + mode="markers", + name=name, + text=labels, + hoverinfo=hoverinfo, + ) + + if show_CI: + if x_se is not None: + trace.update( + error_x={ + "type": "data", + "array": np.multiply(x_se, Z), + "color": rgba(color, CI_OPACITY), + } + ) + if y_se is not None: + trace.update( + error_y={ + "type": "data", + "array": np.multiply(y_se, Z), + "color": rgba(color, CI_OPACITY), + } + ) + if visible is not None: + trace.update(visible=visible) + if legendgroup is not None: + trace.update(legendgroup=legendgroup) + if showlegend is not None: + trace.update(showlegend=showlegend) + return trace diff --git a/ax/analysis/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/helpers/tests/test_cross_validation_helpers.py new file mode 100644 index 00000000000..bf5dea9f445 --- /dev/null +++ b/ax/analysis/helpers/tests/test_cross_validation_helpers.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import plotly.graph_objects as go + +from ax.analysis.helpers.constants import Z + +from ax.analysis.helpers.cross_validation_helpers import ( + get_cv_plot_data, + get_min_max_with_errors, + obs_vs_pred_dropdown_plot, +) + +from ax.exceptions.core import UnsupportedPlotError +from ax.modelbridge.cross_validation import cross_validate +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import fast_botorch_optimize + + +class TestCrossValidationHelpers(TestCase): + @fast_botorch_optimize + def setUp(self) -> None: + exp = get_branin_experiment(with_batch=True) + exp.trials[0].run() + self.model = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + self.exp_status_quo = get_branin_experiment( + with_batch=True, with_status_quo=True + ) + self.exp_status_quo.trials[0].run() + self.model_status_quo = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + def test_get_min_max_with_errors(self) -> None: + # Test with sample data + x = [1.0, 2.0, 3.0] + y = [4.0, 5.0, 6.0] + sd_x = [0.1, 0.2, 0.3] + sd_y = [0.1, 0.2, 0.3] + min_, max_ = get_min_max_with_errors(x, y, sd_x, sd_y) + + expected_min = 1.0 - 0.1 * Z + expected_max = 6.0 + 0.3 * Z + # Check that the returned values are correct + print(f"min: {min_} {expected_min=}") + print(f"max: {max_} {expected_max=}") + self.assertAlmostEqual(min_, expected_min, delta=1e-4) + self.assertAlmostEqual(max_, expected_max, delta=1e-4) + + def test_obs_vs_pred_dropdown_plot(self) -> None: + for autoset_axis_limits in [False, True]: + for show_context in [False, True]: + # Assert that each type of plot can be constructed successfully + print(f"{autoset_axis_limits=} {show_context=}") + + cv_results = cross_validate(self.model) + + label_dict = {"branin": "BrAnIn"} + + data = get_cv_plot_data(cv_results, label_dict=label_dict) + fig = obs_vs_pred_dropdown_plot( + data=data, + rel=False, + show_context=show_context, + autoset_axis_limits=autoset_axis_limits, + ) + + self.assertIsInstance(fig, go.Figure) + + def test_obs_vs_pred_dropdown_plot_relative_effects(self) -> None: + for autoset_axis_limits in [False, True]: + for show_context in [False, True]: + # Assert that each type of plot can be constructed successfully + print(f"{autoset_axis_limits=} {show_context=}") + + cv_results = cross_validate(self.model_status_quo) + + label_dict = {"branin": "BrAnIn"} + + data = get_cv_plot_data(cv_results, label_dict=label_dict) + in_sample = data.in_sample + data_value = next(iter(in_sample.values())) + in_sample[self.exp_status_quo.status_quo.name] = data_value + data = data._replace( + in_sample=in_sample, + status_quo_name=self.exp_status_quo.status_quo.name, + ) + + if show_context: + with self.assertRaisesRegex( + UnsupportedPlotError, + "This plot does not support both context" + " and relativization at the same time", + ): + _ = obs_vs_pred_dropdown_plot( + data=data, + rel=True, + show_context=show_context, + autoset_axis_limits=autoset_axis_limits, + ) + continue + fig = obs_vs_pred_dropdown_plot( + data=data, + rel=True, + show_context=show_context, + autoset_axis_limits=autoset_axis_limits, + ) + + self.assertIsInstance(fig, go.Figure) diff --git a/ax/analysis/helpers/tests/test_scatter_helpers.py b/ax/analysis/helpers/tests/test_scatter_helpers.py new file mode 100644 index 00000000000..d2317ee7644 --- /dev/null +++ b/ax/analysis/helpers/tests/test_scatter_helpers.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import fast_botorch_optimize + + +class TestScatterHelpers(TestCase): + @fast_botorch_optimize + def setUp(self) -> None: + exp = get_branin_experiment(with_batch=True) + exp.trials[0].run() + self.model = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + self.exp_status_quo = get_branin_experiment( + with_batch=True, with_status_quo=True + ) + self.exp_status_quo.trials[0].run() + self.model_status_quo = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + def test_error_scatter_data(self) -> None: + """ + stub + """ + + def test_error_scatter_trace(self) -> None: + """ + stub + """