diff --git a/ax/analysis/helpers/color_helpers.py b/ax/analysis/helpers/color_helpers.py new file mode 100644 index 00000000000..575b641f014 --- /dev/null +++ b/ax/analysis/helpers/color_helpers.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from numbers import Real +from typing import List, Tuple + +# type aliases +TRGB = Tuple[Real, ...] + + +def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str: + """Convert RGB tuple to an RGBA string.""" + return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha) + + +def plotly_color_scale( + list_of_rgb_tuples: List[TRGB], + reverse: bool = False, + alpha: float = 1, +) -> List[Tuple[float, str]]: + """Convert list of RGB tuples to list of tuples, where each tuple is + break in [0, 1] and stringified RGBA color. + """ + if reverse: + list_of_rgb_tuples = list_of_rgb_tuples[::-1] + return [ + (round(i / (len(list_of_rgb_tuples) - 1), 3), rgba(rgb)) + for i, rgb in enumerate(list_of_rgb_tuples) + ] diff --git a/ax/analysis/helpers/constants.py b/ax/analysis/helpers/constants.py new file mode 100644 index 00000000000..c5af778ebac --- /dev/null +++ b/ax/analysis/helpers/constants.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import enum + +# Constants used for numerous plots +CI_OPACITY = 0.4 +DECIMALS = 3 +Z = 1.96 + + +# color constants used for plotting +class COLORS(enum.Enum): + STEELBLUE = (128, 177, 211) + CORAL = (251, 128, 114) + TEAL = (141, 211, 199) + PINK = (188, 128, 189) + LIGHT_PURPLE = (190, 186, 218) + ORANGE = (253, 180, 98) + + +# colors to be used for plotting discrete series +# pyre-fixme[5]: Global expression must be annotated. +DISCRETE_COLOR_SCALE = [ + COLORS.STEELBLUE.value, + COLORS.CORAL.value, + COLORS.PINK.value, + COLORS.LIGHT_PURPLE.value, + COLORS.ORANGE.value, + COLORS.TEAL.value, +] + +# 11-class PiYG from ColorBrewer (for contour plots) +GREEN_PINK_SCALE = [ + (142, 1, 82), + (197, 27, 125), + (222, 119, 174), + (241, 182, 218), + (253, 224, 239), + (247, 247, 247), + (230, 245, 208), + (184, 225, 134), + (127, 188, 65), + (77, 146, 33), + (39, 100, 25), +] +GREEN_SCALE = [ + (247, 252, 253), + (229, 245, 249), + (204, 236, 230), + (153, 216, 201), + (102, 194, 164), + (65, 174, 118), + (35, 139, 69), + (0, 109, 44), + (0, 68, 27), +] +BLUE_SCALE = [ + (255, 247, 251), + (236, 231, 242), + (208, 209, 230), + (166, 189, 219), + (116, 169, 207), + (54, 144, 192), + (5, 112, 176), + (3, 78, 123), +] +# 24 Class Mixed Color Palette +# Source: https://graphicdesign.stackexchange.com/a/3815 +MIXED_SCALE = [ + (2, 63, 165), + (125, 135, 185), + (190, 193, 212), + (214, 188, 192), + (187, 119, 132), + (142, 6, 59), + (74, 111, 227), + (133, 149, 225), + (181, 187, 227), + (230, 175, 185), + (224, 123, 145), + (211, 63, 106), + (17, 198, 56), + (141, 213, 147), + (198, 222, 199), + (234, 211, 198), + (240, 185, 141), + (239, 151, 8), + (15, 207, 192), + (156, 222, 214), + (213, 234, 231), + (243, 225, 235), + (246, 196, 225), + (247, 156, 212), +] diff --git a/ax/analysis/helpers/cross_validation_helpers.py b/ax/analysis/helpers/cross_validation_helpers.py new file mode 100644 index 00000000000..b6fac5b73fa --- /dev/null +++ b/ax/analysis/helpers/cross_validation_helpers.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import plotly.graph_objs as go + +from ax.analysis.helpers.constants import Z + +from ax.analysis.helpers.layout_helpers import layout_format, updatemenus_format +from ax.analysis.helpers.scatter_helpers import ( + _error_scatter_data, + _error_scatter_trace, + PlotData, + PlotInSampleArm, + PlotMetric, +) + +from ax.exceptions.core import UnsupportedPlotError + +from ax.modelbridge.cross_validation import CVResult + + +# Helper functions for plotting model fits +def get_min_max_with_errors( + x: List[float], y: List[float], sd_x: List[float], sd_y: List[float] +) -> Tuple[float, float]: + """Get min and max of a bivariate dataset (across variables). + + Args: + x: point estimate of x variable. + y: point estimate of y variable. + sd_x: standard deviation of x variable. + sd_y: standard deviation of y variable. + + Returns: + min_: minimum of points, including uncertainty. + max_: maximum of points, including uncertainty. + + """ + min_ = min( + min(np.array(x) - np.multiply(sd_x, Z)), min(np.array(y) - np.multiply(sd_y, Z)) + ) + max_ = max( + max(np.array(x) + np.multiply(sd_x, Z)), max(np.array(y) + np.multiply(sd_y, Z)) + ) + return min_, max_ + + +def diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str, Any]: + """Diagonal line trace from (min_, min_) to (max_, max_). + + Args: + min_: minimum to be used for starting point of line. + max_: maximum to be used for ending point of line. + visible: if True, trace is set to visible. + + """ + return go.Scatter( + x=[min_, max_], + y=[min_, max_], + line=dict(color="black", width=2, dash="dot"), # noqa: C408 + mode="lines", + hoverinfo="none", + visible=visible, + showlegend=False, + ) + + +def default_value_se_raw(se_raw: Optional[List[float]], out_length: int) -> List[float]: + """ + Takes a list of standard errors and maps edge cases to default list + of floats. + + """ + new_se_raw = ( + [0.0 if np.isnan(se) else se for se in se_raw] + if se_raw is not None + else [0.0] * out_length + ) + return new_se_raw + + +def obs_vs_pred_dropdown_plot( + data: PlotData, + rel: bool, + show_context: bool = False, + xlabel: str = "Actual Outcome", + ylabel: str = "Predicted Outcome", + autoset_axis_limits: bool = True, +) -> go.Figure: + """Plot a dropdown plot of observed vs. predicted values from a model. + NOTE: If relative, cannot show additional context + on arm details on hover + + Args: + data: a name tuple storing observed and predicted data + from a model. + rel: if True, plot metrics relative to the status quo. + show_context: Show arm detail context on hover. + xlabel: Label for x-axis. + ylabel: Label for y-axis. + autoset_axis_limits: Automatically try to set the limit for each axis to focus + on the region of interest. + """ + traces = [] + metric_dropdown = [] + layout_axis_range = [] + status_quo_arm = None + if rel and data.status_quo_name is not None: + if show_context: + raise UnsupportedPlotError( + "This plot does not support both context and relativization at " + "the same time." + ) + status_quo_arm = data.in_sample[data.status_quo_name] + + for i, metric in enumerate(data.metrics): + y_raw, se_raw, y_hat, se_hat = _error_scatter_data( + list(data.in_sample.values()), + y_axis_var=PlotMetric(metric_name=metric, pred=True, rel=rel), + x_axis_var=PlotMetric(metric_name=metric, pred=False, rel=rel), + status_quo_arm=status_quo_arm, + ) + se_raw = default_value_se_raw(se_raw=se_raw, out_length=len(y_raw)) + + min_, max_ = get_min_max_with_errors(x=y_raw, y=y_hat, sd_x=se_raw, sd_y=se_hat) + if autoset_axis_limits: + y_raw_np = np.array(y_raw) + # TODO: replace interpolation->method once it becomes standard. + q1 = np.nanpercentile(y_raw_np, q=25, interpolation="lower").min() + q3 = np.nanpercentile(y_raw_np, q=75, interpolation="higher").max() + y_lower = q1 - 1.5 * (q3 - q1) + y_upper = q3 + 1.5 * (q3 - q1) + y_raw_np = y_raw_np.clip(y_lower, y_upper).tolist() + min_robust, max_robust = get_min_max_with_errors( + x=y_raw_np, y=y_hat, sd_x=se_raw, sd_y=se_hat + ) + y_padding = 0.05 * (max_robust - min_robust) + # Use the min/max of the limits + layout_axis_range.append( + [max(min_robust, min_) - y_padding, min(max_robust, max_) + y_padding] + ) + traces.append( + diagonal_trace( + min(min_robust, min_) - y_padding, + max(max_robust, max_) + y_padding, + visible=(i == 0), + ) + ) + else: + layout_axis_range.append(None) + traces.append(diagonal_trace(min_, max_, visible=(i == 0))) + + traces.append( + _error_scatter_trace( + arms=list(data.in_sample.values()), + hoverinfo="text", + show_arm_details_on_hover=True, + show_CI=True, + show_context=show_context, + status_quo_arm=status_quo_arm, + visible=(i == 0), + x_axis_label=xlabel, + x_axis_var=PlotMetric(metric_name=metric, pred=False, rel=rel), + y_axis_label=ylabel, + y_axis_var=PlotMetric(metric_name=metric, pred=True, rel=rel), + ) + ) + + # only the first two traces are visible (corresponding to first outcome + # in dropdown) + is_visible = [False] * (len(data.metrics) * 2) + is_visible[2 * i] = True + is_visible[2 * i + 1] = True + + # on dropdown change, restyle + metric_dropdown.append( + { + "args": [ + {"visible": is_visible}, + { + "xaxis.range": layout_axis_range[-1], + "yaxis.range": layout_axis_range[-1], + }, + ], + "label": metric, + "method": "update", + } + ) + + updatemenus = updatemenus_format(metric_dropdown=metric_dropdown) + layout = layout_format( + layout_axis_range_value=layout_axis_range[0], + xlabel=xlabel, + ylabel=ylabel, + updatemenus=updatemenus, + ) + + return go.Figure(data=traces, layout=layout) + + +def get_cv_plot_data( + cv_results: List[CVResult], label_dict: Optional[Dict[str, str]] +) -> PlotData: + """Construct PlotData from cv_results, mapping observed to y and se, + and predicted to y_hat and se_hat. + + Args: + cv_results: A CVResult for each observation in the training data. + label_dict: optional map from real metric names to shortened names + + Returns: + PlotData with the following fields: + metrics: List[str] + in_sample: Dict[str, PlotInSampleArm] + PlotInSample arm have the fields + { + "name" + "y" + "se" + "parameters" + "y_hat" + "se_hat" + "context_stratum": (Always NONE) + } + out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]] + (NOTE: Always None) + status_quo_name: Optional[str] + (NOTE: Always None) + + """ + if len(cv_results) == 0: + return PlotData( + metrics=[], in_sample={}, out_of_sample=None, status_quo_name=None + ) + + if label_dict is None: + label_dict = {} + # Apply label_dict to cv_results + cv_results = deepcopy(cv_results) # Copy and edit in-place + for cv_i in cv_results: + cv_i.observed.data.metric_names = [ + label_dict.get(m, m) for m in cv_i.observed.data.metric_names + ] + cv_i.predicted.metric_names = [ + label_dict.get(m, m) for m in cv_i.predicted.metric_names + ] + + # arm_name -> Arm data + insample_data: Dict[str, PlotInSampleArm] = {} + + # Get the union of all metric_names seen in predictions + metric_names = list( + set().union(*(cv_result.predicted.metric_names for cv_result in cv_results)) + ) + + for rid, cv_result in enumerate(cv_results): + arm_name = cv_result.observed.arm_name + arm_data = { + "name": cv_result.observed.arm_name, + "y": {}, + "se": {}, + "parameters": cv_result.observed.features.parameters, + "y_hat": {}, + "se_hat": {}, + "context_stratum": None, + } + for i, mname in enumerate(cv_result.observed.data.metric_names): + # pyre-fixme[16]: Optional type has no attribute `__setitem__`. + arm_data["y"][mname] = cv_result.observed.data.means[i] + # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any, + # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]` + # has no attribute `__setitem__`. + arm_data["se"][mname] = np.sqrt(cv_result.observed.data.covariance[i][i]) + for i, mname in enumerate(cv_result.predicted.metric_names): + # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any, + # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]` + # has no attribute `__setitem__`. + arm_data["y_hat"][mname] = cv_result.predicted.means[i] + # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any, + # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]` + # has no attribute `__setitem__`. + arm_data["se_hat"][mname] = np.sqrt(cv_result.predicted.covariance[i][i]) + + # Expected `str` for 2nd anonymous parameter to call `dict.__setitem__` but got + # `Optional[str]`. + # pyre-fixme[6]: + insample_data[f"{arm_name}_{rid}"] = PlotInSampleArm(**arm_data) + return PlotData( + metrics=metric_names, + in_sample=insample_data, + out_of_sample=None, + status_quo_name=None, + ) diff --git a/ax/analysis/helpers/layout_helpers.py b/ax/analysis/helpers/layout_helpers.py new file mode 100644 index 00000000000..ccde6993377 --- /dev/null +++ b/ax/analysis/helpers/layout_helpers.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple, Type + +import plotly.graph_objs as go + + +def updatemenus_format(metric_dropdown: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [ + { + "x": 0, + "y": 1.125, + "yanchor": "top", + "xanchor": "left", + "buttons": metric_dropdown, + }, + { + "buttons": [ + { + "args": [ + { + "error_x.width": 4, + "error_x.thickness": 2, + "error_y.width": 4, + "error_y.thickness": 2, + } + ], + "label": "Yes", + "method": "restyle", + }, + { + "args": [ + { + "error_x.width": 0, + "error_x.thickness": 0, + "error_y.width": 0, + "error_y.thickness": 0, + } + ], + "label": "No", + "method": "restyle", + }, + ], + "x": 1.125, + "xanchor": "left", + "y": 0.8, + "yanchor": "middle", + }, + ] + + +def layout_format( + layout_axis_range_value: Tuple[float, float], + xlabel: str, + ylabel: str, + updatemenus: List[Dict[str, Any]], +) -> Type[go.graph_objs.Figure]: + layout = go.Layout( + annotations=[ + { + "showarrow": False, + "text": "Show CI", + "x": 1.125, + "xanchor": "left", + "xref": "paper", + "y": 0.9, + "yanchor": "middle", + "yref": "paper", + } + ], + xaxis={ + "range": layout_axis_range_value, + "title": xlabel, + "zeroline": False, + "mirror": True, + "linecolor": "black", + "linewidth": 0.5, + }, + yaxis={ + "range": layout_axis_range_value, + "title": ylabel, + "zeroline": False, + "mirror": True, + "linecolor": "black", + "linewidth": 0.5, + }, + showlegend=False, + hovermode="closest", + updatemenus=updatemenus, + width=530, + height=500, + ) + return layout diff --git a/ax/analysis/helpers/plot_helpers.py b/ax/analysis/helpers/plot_helpers.py new file mode 100644 index 00000000000..60aab0373c2 --- /dev/null +++ b/ax/analysis/helpers/plot_helpers.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from logging import Logger +from typing import Dict, List, Optional, Tuple, Union + +from ax.analysis.helpers.constants import DECIMALS, Z + +from ax.core.generator_run import GeneratorRun + +from ax.core.types import TParameterization +from ax.utils.common.logger import get_logger + + +logger: Logger = get_logger(__name__) + +# Typing alias +RawData = List[Dict[str, Union[str, float]]] + +TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]] + + +def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str: + """Format a dictionary for labels. + + Args: + param_dict: Dictionary to be formatted + name: String name of the thing being formatted. + + Returns: stringified blob. + """ + if len(param_dict) >= 10: + blob = "{} has too many items to render on hover ({}).".format( + name, len(param_dict) + ) + else: + blob = "
{}:
{}".format( + name, "
".join("{}: {}".format(n, v) for n, v in param_dict.items()) + ) + return blob + + +def _format_CI(estimate: float, sd: float, relative: bool, zval: float = Z) -> str: + """Format confidence intervals given estimate and standard deviation. + + Args: + estimate: point estimate. + sd: standard deviation of point estimate. + relative: if True, '%' is appended. + zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs) + + Returns: formatted confidence interval. + """ + return "[{lb:.{digits}f}{perc}, {ub:.{digits}f}{perc}]".format( + lb=estimate - zval * sd, + ub=estimate + zval * sd, + digits=DECIMALS, + perc="%" if relative else "", + ) + + +def arm_name_to_sort_key(arm_name: str) -> Tuple[str, int, int]: + """Parses arm name into tuple suitable for reverse sorting by key + + Example: + arm_names = ["0_0", "1_10", "1_2", "10_0", "control"] + sorted(arm_names, key=arm_name_to_sort_key, reverse=True) + ["control", "0_0", "1_2", "1_10", "10_0"] + """ + try: + trial_index, arm_index = arm_name.split("_") + return ("", -int(trial_index), -int(arm_index)) + except (ValueError, IndexError): + return (arm_name, 0, 0) diff --git a/ax/analysis/helpers/scatter_helpers.py b/ax/analysis/helpers/scatter_helpers.py new file mode 100644 index 00000000000..62c9431d4fa --- /dev/null +++ b/ax/analysis/helpers/scatter_helpers.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numbers + +from typing import ( + Any, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) + +import numpy as np +import plotly.graph_objs as go +from ax.analysis.helpers.color_helpers import rgba + +from ax.analysis.helpers.constants import BLUE_SCALE, CI_OPACITY, COLORS, DECIMALS, Z + +from ax.analysis.helpers.plot_helpers import ( + _format_CI, + _format_dict, + arm_name_to_sort_key, +) + +from ax.core.types import TParameterization + +from ax.exceptions.core import UnsupportedPlotError + +from ax.utils.stats.statstools import relativize + + +# Structs for plot data +class PlotMetric(NamedTuple): + """Struct for metric""" + + metric_name: str + pred: bool + rel: bool + + +class PlotInSampleArm(NamedTuple): + """Struct for in-sample arms (both observed and predicted data)""" + + name: str + parameters: TParameterization + y: Dict[str, float] + y_hat: Dict[str, float] + se: Dict[str, float] + se_hat: Dict[str, float] + context_stratum: Optional[Dict[str, Union[str, float]]] + + +class PlotOutOfSampleArm(NamedTuple): + """Struct for out-of-sample arms (only predicted data)""" + + name: str + parameters: TParameterization + y_hat: Dict[str, float] + se_hat: Dict[str, float] + context_stratum: Optional[Dict[str, Union[str, float]]] + + +class PlotData(NamedTuple): + """Struct for plot data, including both in-sample and out-of-sample arms""" + + metrics: List[str] + in_sample: Dict[str, PlotInSampleArm] + out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]] + status_quo_name: Optional[str] + + +def _error_scatter_data( + arms: Iterable[Union[PlotInSampleArm, PlotOutOfSampleArm]], + y_axis_var: PlotMetric, + x_axis_var: Optional[PlotMetric] = None, + status_quo_arm: Optional[PlotInSampleArm] = None, +) -> Tuple[List[float], Optional[List[float]], List[float], List[float]]: + y_metric_key = "y_hat" if y_axis_var.pred else "y" + y_sd_key = "se_hat" if y_axis_var.pred else "se" + + arm_names = [a.name for a in arms] + y = [getattr(a, y_metric_key).get(y_axis_var.metric_name, np.nan) for a in arms] + y_se = [getattr(a, y_sd_key).get(y_axis_var.metric_name, np.nan) for a in arms] + + # Delta method if relative to status quo arm + if y_axis_var.rel: + if status_quo_arm is None: + raise UnsupportedPlotError( + "`status_quo_arm` cannot be None for relative effects." + ) + y_rel, y_se_rel = relativize( + means_t=y, + sems_t=y_se, + mean_c=getattr(status_quo_arm, y_metric_key).get(y_axis_var.metric_name), + sem_c=getattr(status_quo_arm, y_sd_key).get(y_axis_var.metric_name), + as_percent=True, + ) + y = y_rel.tolist() + y_se = y_se_rel.tolist() + + # x can be metric for a metric or arm names + if x_axis_var is None: + x = arm_names + x_se = None + else: + x_metric_key = "y_hat" if x_axis_var.pred else "y" + x_sd_key = "se_hat" if x_axis_var.pred else "se" + x = [getattr(a, x_metric_key).get(x_axis_var.metric_name, np.nan) for a in arms] + x_se = [getattr(a, x_sd_key).get(x_axis_var.metric_name, np.nan) for a in arms] + + if x_axis_var.rel: + # Delta method if relative to status quo arm + x_rel, x_se_rel = relativize( + means_t=x, + sems_t=x_se, + mean_c=getattr(status_quo_arm, x_metric_key).get( + x_axis_var.metric_name + ), + sem_c=getattr(status_quo_arm, x_sd_key).get(x_axis_var.metric_name), + as_percent=True, + ) + x = x_rel.tolist() + x_se = x_se_rel.tolist() + return x, x_se, y, y_se + + +def _error_scatter_trace( + arms: Sequence[Union[PlotInSampleArm, PlotOutOfSampleArm]], + y_axis_var: PlotMetric, + x_axis_var: Optional[PlotMetric] = None, + y_axis_label: Optional[str] = None, + x_axis_label: Optional[str] = None, + status_quo_arm: Optional[PlotInSampleArm] = None, + show_CI: bool = True, + name: str = "In-sample", + color: Tuple[int] = COLORS.STEELBLUE.value, + visible: bool = True, + legendgroup: Optional[str] = None, + showlegend: bool = True, + hoverinfo: str = "text", + show_arm_details_on_hover: bool = True, + show_context: bool = False, + arm_noun: str = "arm", + color_parameter: Optional[str] = None, + color_metric: Optional[str] = None, +) -> Dict[str, Any]: + """Plot scatterplot with error bars. + + Args: + arms (List[Union[PlotInSampleArm, PlotOutOfSampleArm]]): + a list of in-sample or out-of-sample arms. + In-sample arms have observed data, while out-of-sample arms + just have predicted data. As a result, + when passing out-of-sample arms, pred must be True. + y_axis_var: name of metric for y-axis, along with whether + it is observed or predicted. + x_axis_var: name of metric for x-axis, + along with whether it is observed or predicted. If None, arm names + are automatically used. + y_axis_label: custom label to use for y axis. + If None, use metric name from `y_axis_var`. + x_axis_label: custom label to use for x axis. + If None, use metric name from `x_axis_var` if that is not None. + status_quo_arm: the status quo + arm. Necessary for relative metrics. + show_CI: if True, plot confidence intervals. + name: name of trace. Default is "In-sample". + color: color as rgb tuple. Default is + (128, 177, 211), which corresponds to COLORS.STEELBLUE. + visible: if True, trace is visible (default). + legendgroup: group for legends. + showlegend: if True, legend if rendered. + hoverinfo: information to show on hover. Default is + custom text. + show_arm_details_on_hover: if True, display + parameterizations of arms on hover. Default is True. + show_context: if True and show_arm_details_on_hover, + context will be included in the hover. + arm_noun: noun to use instead of "arm" (e.g. group) + color_parameter: color points according to the specified parameter, + cannot be used together with color_metric. + color_metric: color points according to the specified metric, + cannot be used together with color_parameter. + """ + if color_metric and color_parameter: + raise RuntimeError( + "color_metric and color_parameter cannot be used at the same time!" + ) + + if (color_metric or color_parameter) and not all( + isinstance(arm, PlotInSampleArm) for arm in arms + ): + raise RuntimeError("Color coding currently only works with in-sample arms!") + + # Opportunistically sort if arm names are in {trial}_{arm} format + arms = sorted(arms, key=lambda a: arm_name_to_sort_key(a.name), reverse=True) + + x, x_se, y, y_se = _error_scatter_data( + arms=arms, + y_axis_var=y_axis_var, + x_axis_var=x_axis_var, + status_quo_arm=status_quo_arm, + ) + labels = [] + colors = [] + + arm_names = [a.name for a in arms] + + # No relativization if no x variable. + rel_x = x_axis_var.rel if x_axis_var else False + rel_y = y_axis_var.rel + + for i in range(len(arm_names)): + heading = f"{arm_noun.title()} {arm_names[i]}
" + x_lab = ( + "{name}: {estimate}{perc} {ci}
".format( + name=x_axis_var.metric_name if x_axis_label is None else x_axis_label, + estimate=( + round(x[i], DECIMALS) if isinstance(x[i], numbers.Number) else x[i] + ), + ci="" if x_se is None else _format_CI(x[i], x_se[i], rel_x), + perc="%" if rel_x else "", + ) + if x_axis_var is not None + else "" + ) + y_lab = "{name}: {estimate}{perc} {ci}
".format( + name=y_axis_var.metric_name if y_axis_label is None else y_axis_label, + estimate=( + round(y[i], DECIMALS) if isinstance(y[i], numbers.Number) else y[i] + ), + ci="" if y_se is None else _format_CI(y[i], y_se[i], rel_y), + perc="%" if rel_y else "", + ) + + parameterization = ( + _format_dict(arms[i].parameters, "Parameterization") + if show_arm_details_on_hover + else "" + ) + + if color_parameter: + colors.append(arms[i].parameters[color_parameter]) + elif color_metric: + arm = arms[i] + if isinstance(arm, PlotInSampleArm): + colors.append(arm.y[color_metric]) + + arm = arms[i] + context = ( + # Expected `Dict[str, Optional[Union[bool, float, str]]]` for 1st anonymous + # parameter to call `ax.analysis.helpers.plot_helpers._format_dict` but got + # `Optional[Dict[str, Union[float, str]]]`. + # pyre-fixme[6]: + _format_dict(arms[i].context_stratum, "Context") + if show_arm_details_on_hover + and show_context # noqa W503 + and arms[i].context_stratum # noqa W503 + else "" + ) + + labels.append( + "{arm_name}
{xlab}{ylab}{param_blob}{context}".format( + arm_name=heading, + xlab=x_lab, + ylab=y_lab, + param_blob=parameterization, + context=context, + ) + ) + i += 1 + + if color_metric or color_parameter: + rgba_blue_scale = [rgba(c) for c in BLUE_SCALE] + marker = { + "color": colors, + "colorscale": rgba_blue_scale, + "colorbar": {"title": color_metric or color_parameter}, + "showscale": True, + } + else: + marker = {"color": rgba(color)} + + trace = go.Scatter( + x=x, + y=y, + marker=marker, + mode="markers", + name=name, + text=labels, + hoverinfo=hoverinfo, + ) + + if show_CI: + if x_se is not None: + trace.update( + error_x={ + "type": "data", + "array": np.multiply(x_se, Z), + "color": rgba(color, CI_OPACITY), + } + ) + if y_se is not None: + trace.update( + error_y={ + "type": "data", + "array": np.multiply(y_se, Z), + "color": rgba(color, CI_OPACITY), + } + ) + if visible is not None: + trace.update(visible=visible) + if legendgroup is not None: + trace.update(legendgroup=legendgroup) + if showlegend is not None: + trace.update(showlegend=showlegend) + return trace diff --git a/ax/analysis/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/helpers/tests/test_cross_validation_helpers.py new file mode 100644 index 00000000000..bf5dea9f445 --- /dev/null +++ b/ax/analysis/helpers/tests/test_cross_validation_helpers.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import plotly.graph_objects as go + +from ax.analysis.helpers.constants import Z + +from ax.analysis.helpers.cross_validation_helpers import ( + get_cv_plot_data, + get_min_max_with_errors, + obs_vs_pred_dropdown_plot, +) + +from ax.exceptions.core import UnsupportedPlotError +from ax.modelbridge.cross_validation import cross_validate +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import fast_botorch_optimize + + +class TestCrossValidationHelpers(TestCase): + @fast_botorch_optimize + def setUp(self) -> None: + exp = get_branin_experiment(with_batch=True) + exp.trials[0].run() + self.model = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + self.exp_status_quo = get_branin_experiment( + with_batch=True, with_status_quo=True + ) + self.exp_status_quo.trials[0].run() + self.model_status_quo = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + def test_get_min_max_with_errors(self) -> None: + # Test with sample data + x = [1.0, 2.0, 3.0] + y = [4.0, 5.0, 6.0] + sd_x = [0.1, 0.2, 0.3] + sd_y = [0.1, 0.2, 0.3] + min_, max_ = get_min_max_with_errors(x, y, sd_x, sd_y) + + expected_min = 1.0 - 0.1 * Z + expected_max = 6.0 + 0.3 * Z + # Check that the returned values are correct + print(f"min: {min_} {expected_min=}") + print(f"max: {max_} {expected_max=}") + self.assertAlmostEqual(min_, expected_min, delta=1e-4) + self.assertAlmostEqual(max_, expected_max, delta=1e-4) + + def test_obs_vs_pred_dropdown_plot(self) -> None: + for autoset_axis_limits in [False, True]: + for show_context in [False, True]: + # Assert that each type of plot can be constructed successfully + print(f"{autoset_axis_limits=} {show_context=}") + + cv_results = cross_validate(self.model) + + label_dict = {"branin": "BrAnIn"} + + data = get_cv_plot_data(cv_results, label_dict=label_dict) + fig = obs_vs_pred_dropdown_plot( + data=data, + rel=False, + show_context=show_context, + autoset_axis_limits=autoset_axis_limits, + ) + + self.assertIsInstance(fig, go.Figure) + + def test_obs_vs_pred_dropdown_plot_relative_effects(self) -> None: + for autoset_axis_limits in [False, True]: + for show_context in [False, True]: + # Assert that each type of plot can be constructed successfully + print(f"{autoset_axis_limits=} {show_context=}") + + cv_results = cross_validate(self.model_status_quo) + + label_dict = {"branin": "BrAnIn"} + + data = get_cv_plot_data(cv_results, label_dict=label_dict) + in_sample = data.in_sample + data_value = next(iter(in_sample.values())) + in_sample[self.exp_status_quo.status_quo.name] = data_value + data = data._replace( + in_sample=in_sample, + status_quo_name=self.exp_status_quo.status_quo.name, + ) + + if show_context: + with self.assertRaisesRegex( + UnsupportedPlotError, + "This plot does not support both context" + " and relativization at the same time", + ): + _ = obs_vs_pred_dropdown_plot( + data=data, + rel=True, + show_context=show_context, + autoset_axis_limits=autoset_axis_limits, + ) + continue + fig = obs_vs_pred_dropdown_plot( + data=data, + rel=True, + show_context=show_context, + autoset_axis_limits=autoset_axis_limits, + ) + + self.assertIsInstance(fig, go.Figure) diff --git a/ax/analysis/helpers/tests/test_scatter_helpers.py b/ax/analysis/helpers/tests/test_scatter_helpers.py new file mode 100644 index 00000000000..d2317ee7644 --- /dev/null +++ b/ax/analysis/helpers/tests/test_scatter_helpers.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import fast_botorch_optimize + + +class TestScatterHelpers(TestCase): + @fast_botorch_optimize + def setUp(self) -> None: + exp = get_branin_experiment(with_batch=True) + exp.trials[0].run() + self.model = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + self.exp_status_quo = get_branin_experiment( + with_batch=True, with_status_quo=True + ) + self.exp_status_quo.trials[0].run() + self.model_status_quo = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + def test_error_scatter_data(self) -> None: + """ + stub + """ + + def test_error_scatter_trace(self) -> None: + """ + stub + """