From 2369b6d45a8e14ae8712ab0fd19924e1ba3c34b4 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Wed, 6 Mar 2024 10:51:08 -0800 Subject: [PATCH] copying "cross_validation_helper" code from ax.plot (#2249) Summary: To prepare of creating the "CrossValidationPlot" module, this change imports the dependent code from "ax.plots". It also cleans up the original code by breaking it apart into different helper files, and trimming out methods which are not used to create this plot. Some new unit tests are being added as well. ## In new ax.analysis CrossValidationPlot The function of the code is broken out neatly by function: Constants, string operations, basic formatting helpers - ax/analysis/helpers/constants.py - 21 lines - ax/analysis/helpers/color_helpers.py - 33 lines - ax/analysis/helpers/plot_helpers.py - 76 lines - ax/analysis/helpers/layout_helpers.py - 96 lines Plot Logic - ax/analysis/helpers/scatter_helpers.py - 180 lines - ax/analysis/helpers/cross_validation_helpers.py - 291 lines CrossValidationPlot object - ax/analysis/cross_validation_plot.py - 109 lines 806 total lines ## Required files from ax.plot needed to create cross validation plot - ax/plot/scatter.py - 1722 lines - ax/plot/diagnostic.py - 691 lines - ax/plot/helper - 995 lines - ax/plot/base.py - 94 lines - ax/plot/color.py - 120 lines 3622 total lines of code across the files which have the logic for cross validation plots Differential Revision: D54495372 --- ax/analysis/helpers/color_helpers.py | 33 ++ ax/analysis/helpers/constants.py | 21 ++ .../helpers/cross_validation_helpers.py | 291 ++++++++++++++++++ ax/analysis/helpers/layout_helpers.py | 96 ++++++ ax/analysis/helpers/plot_helpers.py | 76 +++++ ax/analysis/helpers/scatter_helpers.py | 180 +++++++++++ .../tests/test_cross_validation_helpers.py | 71 +++++ 7 files changed, 768 insertions(+) create mode 100644 ax/analysis/helpers/color_helpers.py create mode 100644 ax/analysis/helpers/constants.py create mode 100644 ax/analysis/helpers/cross_validation_helpers.py create mode 100644 ax/analysis/helpers/layout_helpers.py create mode 100644 ax/analysis/helpers/plot_helpers.py create mode 100644 ax/analysis/helpers/scatter_helpers.py create mode 100644 ax/analysis/helpers/tests/test_cross_validation_helpers.py diff --git a/ax/analysis/helpers/color_helpers.py b/ax/analysis/helpers/color_helpers.py new file mode 100644 index 00000000000..575b641f014 --- /dev/null +++ b/ax/analysis/helpers/color_helpers.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from numbers import Real +from typing import List, Tuple + +# type aliases +TRGB = Tuple[Real, ...] + + +def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str: + """Convert RGB tuple to an RGBA string.""" + return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha) + + +def plotly_color_scale( + list_of_rgb_tuples: List[TRGB], + reverse: bool = False, + alpha: float = 1, +) -> List[Tuple[float, str]]: + """Convert list of RGB tuples to list of tuples, where each tuple is + break in [0, 1] and stringified RGBA color. + """ + if reverse: + list_of_rgb_tuples = list_of_rgb_tuples[::-1] + return [ + (round(i / (len(list_of_rgb_tuples) - 1), 3), rgba(rgb)) + for i, rgb in enumerate(list_of_rgb_tuples) + ] diff --git a/ax/analysis/helpers/constants.py b/ax/analysis/helpers/constants.py new file mode 100644 index 00000000000..3c05e859dac --- /dev/null +++ b/ax/analysis/helpers/constants.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import enum + +# Constants used for numerous plots +CI_OPACITY = 0.4 +DECIMALS = 3 +Z = 1.96 + + +# color constants used for plotting +class COLORS(enum.Enum): + STEELBLUE = (128, 177, 211) + CORAL = (251, 128, 114) + TEAL = (141, 211, 199) + PINK = (188, 128, 189) + LIGHT_PURPLE = (190, 186, 218) + ORANGE = (253, 180, 98) diff --git a/ax/analysis/helpers/cross_validation_helpers.py b/ax/analysis/helpers/cross_validation_helpers.py new file mode 100644 index 00000000000..2be5db48878 --- /dev/null +++ b/ax/analysis/helpers/cross_validation_helpers.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import plotly.graph_objs as go + +from ax.analysis.helpers.constants import Z + +from ax.analysis.helpers.layout_helpers import layout_format, updatemenus_format +from ax.analysis.helpers.scatter_helpers import ( + _error_scatter_data, + _error_scatter_trace, + PlotData, + PlotInSampleArm, + PlotMetric, +) + +from ax.modelbridge.cross_validation import CVResult + + +# Helper functions for plotting model fits +def get_min_max_with_errors( + x: List[float], y: List[float], sd_x: List[float], sd_y: List[float] +) -> Tuple[float, float]: + """Get min and max of a bivariate dataset (across variables). + + Args: + x: point estimate of x variable. + y: point estimate of y variable. + sd_x: standard deviation of x variable. + sd_y: standard deviation of y variable. + + Returns: + min_: minimum of points, including uncertainty. + max_: maximum of points, including uncertainty. + + """ + min_ = min( + min(np.array(x) - np.multiply(sd_x, Z)), min(np.array(y) - np.multiply(sd_y, Z)) + ) + max_ = max( + max(np.array(x) + np.multiply(sd_x, Z)), max(np.array(y) + np.multiply(sd_y, Z)) + ) + return min_, max_ + + +def get_plotting_limit_ignore_outliers( + x: List[float], y: List[float], sd_x: List[float], sd_y: List[float] +) -> Tuple[float, float]: + """Get a range for a bivarite dataset based on the 25th and 75th percentiles + Used as plotting limit to ignore outliers. + + Args: + x: point estimate of x variable. + y: point estimate of y variable. + sd_x: standard deviation of x variable. + sd_y: standard deviation of y variable. + + Returns: + min: lower bound of range + max: higher bound of range + + """ + min_, max_ = get_min_max_with_errors(x=x, y=y, sd_x=sd_x, sd_y=sd_y) + + x_np = np.array(x) + # TODO: replace interpolation->method once it becomes standard. + q1 = np.nanpercentile(x_np, q=25, interpolation="lower").min() + q3 = np.nanpercentile(x_np, q=75, interpolation="higher").max() + quartile_difference = q3 - q1 + + y_lower = q1 - 1.5 * quartile_difference + y_upper = q3 + 1.5 * quartile_difference + + # clip outliers from x + x_np = x_np.clip(y_lower, y_upper).tolist() + min_robust, max_robust = get_min_max_with_errors(x=x_np, y=y, sd_x=sd_x, sd_y=sd_y) + y_padding = 0.05 * (max_robust - min_robust) + + return (max(min_robust, min_) - y_padding, min(max_robust, max_) + y_padding) + + +def diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str, Any]: + """Diagonal line trace from (min_, min_) to (max_, max_). + + Args: + min_: minimum to be used for starting point of line. + max_: maximum to be used for ending point of line. + visible: if True, trace is set to visible. + """ + return go.Scatter( + x=[min_, max_], + y=[min_, max_], + line=dict(color="black", width=2, dash="dot"), # noqa: C408 + mode="lines", + hoverinfo="none", + visible=visible, + showlegend=False, + ) + + +def default_value_se_raw(se_raw: Optional[List[float]], out_length: int) -> List[float]: + """ + Takes a list of standard errors and maps edge cases to default list + of floats. + + """ + new_se_raw = ( + [0.0 if np.isnan(se) else se for se in se_raw] + if se_raw is not None + else [0.0] * out_length + ) + return new_se_raw + + +def obs_vs_pred_dropdown_plot( + data: PlotData, + xlabel: str = "Actual Outcome", + ylabel: str = "Predicted Outcome", +) -> go.Figure: + """Plot a dropdown plot of observed vs. predicted values from a model. + + Args: + data: a name tuple storing observed and predicted data + from a model. + xlabel: Label for x-axis. + ylabel: Label for y-axis. + """ + traces = [] + metric_dropdown = [] + layout_axis_range = [] + + for i, metric in enumerate(data.metrics): + y_raw, se_raw, y_hat, se_hat = _error_scatter_data( + list(data.in_sample.values()), + y_axis_var=PlotMetric(metric_name=metric, pred=True), + x_axis_var=PlotMetric(metric_name=metric, pred=False), + ) + se_raw = default_value_se_raw(se_raw=se_raw, out_length=len(y_raw)) + + # Use the min/max of the limits + min_, max_ = get_plotting_limit_ignore_outliers( + x=y_raw, y=y_hat, sd_x=se_raw, sd_y=se_hat + ) + layout_axis_range.append([min_, max_]) + traces.append( + diagonal_trace( + min_, + max_, + visible=(i == 0), + ) + ) + + traces.append( + _error_scatter_trace( + arms=list(data.in_sample.values()), + show_CI=True, + x_axis_label=xlabel, + x_axis_var=PlotMetric(metric_name=metric, pred=False), + y_axis_label=ylabel, + y_axis_var=PlotMetric(metric_name=metric, pred=True), + ) + ) + + # only the first two traces are visible (corresponding to first outcome + # in dropdown) + is_visible = [False] * (len(data.metrics) * 2) + is_visible[2 * i] = True + is_visible[2 * i + 1] = True + + # on dropdown change, restyle + metric_dropdown.append( + { + "args": [ + {"visible": is_visible}, + { + "xaxis.range": layout_axis_range[-1], + "yaxis.range": layout_axis_range[-1], + }, + ], + "label": metric, + "method": "update", + } + ) + + updatemenus = updatemenus_format(metric_dropdown=metric_dropdown) + layout = layout_format( + layout_axis_range_value=layout_axis_range[0], + xlabel=xlabel, + ylabel=ylabel, + updatemenus=updatemenus, + ) + + return go.Figure(data=traces, layout=layout) + + +def remap_label( + cv_results: List[CVResult], label_dict: Dict[str, str] +) -> List[CVResult]: + """Remaps labels in cv_results according to label_dict. + + Args: + cv_results: A CVResult for each observation in the training data. + label_dict: optional map from real metric names to shortened names + + Returns: + A CVResult for each observation in the training data. + """ + cv_results = deepcopy(cv_results) # Copy and edit in-place + for cv_i in cv_results: + cv_i.observed.data.metric_names = [ + label_dict.get(m, m) for m in cv_i.observed.data.metric_names + ] + cv_i.predicted.metric_names = [ + label_dict.get(m, m) for m in cv_i.predicted.metric_names + ] + return cv_results + + +def get_cv_plot_data( + cv_results: List[CVResult], label_dict: Optional[Dict[str, str]] +) -> PlotData: + """Construct PlotData from cv_results, mapping observed to y and se, + and predicted to y_hat and se_hat. + + Args: + cv_results: A CVResult for each observation in the training data. + label_dict: optional map from real metric names to shortened names + + Returns: + PlotData with the following fields: + metrics: List[str] + in_sample: Dict[str, PlotInSampleArm] + PlotInSample arm have the fields + { + "name" + "y" + "se" + "parameters" + "y_hat" + "se_hat" + } + + """ + if len(cv_results) == 0: + return PlotData(metrics=[], in_sample={}) + + if label_dict: + cv_results = remap_label(cv_results=cv_results, label_dict=label_dict) + + # arm_name -> Arm data + insample_data: Dict[str, PlotInSampleArm] = {} + + # Get the union of all metric_names seen in predictions + metric_names = list( + set().union(*(cv_result.predicted.metric_names for cv_result in cv_results)) + ) + + for rid, cv_result in enumerate(cv_results): + arm_name = cv_result.observed.arm_name + y, se, y_hat, se_hat = {}, {}, {}, {} + + arm_data = { + "name": cv_result.observed.arm_name, + "y": y, + "se": se, + "parameters": cv_result.observed.features.parameters, + "y_hat": y_hat, + "se_hat": se_hat, + } + for i, mname in enumerate(cv_result.observed.data.metric_names): + y[mname] = cv_result.observed.data.means[i] + se[mname] = np.sqrt(cv_result.observed.data.covariance[i][i]) + for i, mname in enumerate(cv_result.predicted.metric_names): + y_hat[mname] = cv_result.predicted.means[i] + se_hat[mname] = np.sqrt(cv_result.predicted.covariance[i][i]) + + # Expected `str` for 2nd anonymous parameter to call `dict.__setitem__` but got + # `Optional[str]`. + # pyre-fixme[6]: + insample_data[f"{arm_name}_{rid}"] = PlotInSampleArm(**arm_data) + return PlotData( + metrics=metric_names, + in_sample=insample_data, + ) diff --git a/ax/analysis/helpers/layout_helpers.py b/ax/analysis/helpers/layout_helpers.py new file mode 100644 index 00000000000..e77903cab35 --- /dev/null +++ b/ax/analysis/helpers/layout_helpers.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple, Type + +import plotly.graph_objs as go + + +def updatemenus_format(metric_dropdown: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [ + { + "x": 0, + "y": 1.125, + "yanchor": "top", + "xanchor": "left", + "buttons": metric_dropdown, + }, + { + "buttons": [ + { + "args": [ + { + "error_x.width": 4, + "error_x.thickness": 2, + "error_y.width": 4, + "error_y.thickness": 2, + } + ], + "label": "Yes", + "method": "restyle", + }, + { + "args": [ + { + "error_x.width": 0, + "error_x.thickness": 0, + "error_y.width": 0, + "error_y.thickness": 0, + } + ], + "label": "No", + "method": "restyle", + }, + ], + "x": 1.125, + "xanchor": "left", + "y": 0.8, + "yanchor": "middle", + }, + ] + + +def layout_format( + layout_axis_range_value: Tuple[float, float], + xlabel: str, + ylabel: str, + updatemenus: List[Dict[str, Any]], +) -> Type[go.Figure]: + layout = go.Layout( + annotations=[ + { + "showarrow": False, + "text": "Show CI", + "x": 1.125, + "xanchor": "left", + "xref": "paper", + "y": 0.9, + "yanchor": "middle", + "yref": "paper", + } + ], + xaxis={ + "range": layout_axis_range_value, + "title": xlabel, + "zeroline": False, + "mirror": True, + "linecolor": "black", + "linewidth": 0.5, + }, + yaxis={ + "range": layout_axis_range_value, + "title": ylabel, + "zeroline": False, + "mirror": True, + "linecolor": "black", + "linewidth": 0.5, + }, + showlegend=False, + hovermode="closest", + updatemenus=updatemenus, + width=530, + height=500, + ) + return layout diff --git a/ax/analysis/helpers/plot_helpers.py b/ax/analysis/helpers/plot_helpers.py new file mode 100644 index 00000000000..80713402bb2 --- /dev/null +++ b/ax/analysis/helpers/plot_helpers.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from logging import Logger +from typing import Dict, List, Optional, Tuple, Union + +from ax.analysis.helpers.constants import DECIMALS, Z + +from ax.core.generator_run import GeneratorRun + +from ax.core.types import TParameterization +from ax.utils.common.logger import get_logger + + +logger: Logger = get_logger(__name__) + +# Typing alias +RawData = List[Dict[str, Union[str, float]]] + +TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]] + + +def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str: + """Format a dictionary for labels. + + Args: + param_dict: Dictionary to be formatted + name: String name of the thing being formatted. + + Returns: stringified blob. + """ + if len(param_dict) >= 10: + blob = "{} has too many items to render on hover ({}).".format( + name, len(param_dict) + ) + else: + blob = "
{}:
{}".format( + name, "
".join("{}: {}".format(n, v) for n, v in param_dict.items()) + ) + return blob + + +def _format_CI(estimate: float, sd: float, zval: float = Z) -> str: + """Format confidence intervals given estimate and standard deviation. + + Args: + estimate: point estimate. + sd: standard deviation of point estimate. + zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs) + + Returns: formatted confidence interval. + """ + return "[{lb:.{digits}f}, {ub:.{digits}f}]".format( + lb=estimate - zval * sd, + ub=estimate + zval * sd, + digits=DECIMALS, + ) + + +def arm_name_to_sort_key(arm_name: str) -> Tuple[str, int, int]: + """Parses arm name into tuple suitable for reverse sorting by key + + Example: + arm_names = ["0_0", "1_10", "1_2", "10_0", "control"] + sorted(arm_names, key=arm_name_to_sort_key, reverse=True) + ["control", "0_0", "1_2", "1_10", "10_0"] + """ + try: + trial_index, arm_index = arm_name.split("_") + return ("", -int(trial_index), -int(arm_index)) + except (ValueError, IndexError): + return (arm_name, 0, 0) diff --git a/ax/analysis/helpers/scatter_helpers.py b/ax/analysis/helpers/scatter_helpers.py new file mode 100644 index 00000000000..f0230066b79 --- /dev/null +++ b/ax/analysis/helpers/scatter_helpers.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numbers + +from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple + +import numpy as np +import plotly.graph_objs as go +from ax.analysis.helpers.color_helpers import rgba + +from ax.analysis.helpers.constants import CI_OPACITY, COLORS, DECIMALS, Z + +from ax.analysis.helpers.plot_helpers import ( + _format_CI, + _format_dict, + arm_name_to_sort_key, +) + +from ax.core.types import TParameterization + + +# Structs for plot data +class PlotMetric(NamedTuple): + """Struct for metric""" + + metric_name: str + pred: bool + + +class PlotInSampleArm(NamedTuple): + """Struct for in-sample arms (both observed and predicted data)""" + + name: str + parameters: TParameterization + y: Dict[str, float] + y_hat: Dict[str, float] + se: Dict[str, float] + se_hat: Dict[str, float] + + +class PlotData(NamedTuple): + """Struct for plot data, including metrics and in-sample arms""" + + metrics: List[str] + in_sample: Dict[str, PlotInSampleArm] + + +def _error_scatter_data( + arms: Iterable[PlotInSampleArm], + y_axis_var: PlotMetric, + x_axis_var: Optional[PlotMetric] = None, +) -> Tuple[List[float], Optional[List[float]], List[float], List[float]]: + y_metric_key = "y_hat" if y_axis_var.pred else "y" + y_sd_key = "se_hat" if y_axis_var.pred else "se" + + arm_names = [a.name for a in arms] + y = [getattr(a, y_metric_key).get(y_axis_var.metric_name, np.nan) for a in arms] + y_se = [getattr(a, y_sd_key).get(y_axis_var.metric_name, np.nan) for a in arms] + + # x can be metric for a metric or arm names + if x_axis_var is None: + x = arm_names + x_se = None + else: + x_metric_key = "y_hat" if x_axis_var.pred else "y" + x_sd_key = "se_hat" if x_axis_var.pred else "se" + x = [getattr(a, x_metric_key).get(x_axis_var.metric_name, np.nan) for a in arms] + x_se = [getattr(a, x_sd_key).get(x_axis_var.metric_name, np.nan) for a in arms] + + return x, x_se, y, y_se + + +def _error_scatter_trace( + arms: Sequence[PlotInSampleArm], + y_axis_var: PlotMetric, + x_axis_var: Optional[PlotMetric] = None, + y_axis_label: Optional[str] = None, + x_axis_label: Optional[str] = None, + show_CI: bool = True, +) -> Dict[str, Any]: + """Plot scatterplot with error bars. + + Args: + arms (List[Union[PlotInSampleArm, PlotOutOfSampleArm]]): + a list of in-sample or out-of-sample arms. + In-sample arms have observed data, while out-of-sample arms + just have predicted data. As a result, + when passing out-of-sample arms, pred must be True. + y_axis_var: name of metric for y-axis, along with whether + it is observed or predicted. + x_axis_var: name of metric for x-axis, + along with whether it is observed or predicted. If None, arm names + are automatically used. + y_axis_label: custom label to use for y axis. + If None, use metric name from `y_axis_var`. + x_axis_label: custom label to use for x axis. + If None, use metric name from `x_axis_var` if that is not None. + show_CI: if True, plot confidence intervals. + """ + + # Opportunistically sort if arm names are in {trial}_{arm} format + arms = sorted(arms, key=lambda a: arm_name_to_sort_key(a.name), reverse=True) + + x, x_se, y, y_se = _error_scatter_data( + arms=arms, + y_axis_var=y_axis_var, + x_axis_var=x_axis_var, + ) + labels = [] + + arm_names = [a.name for a in arms] + + for i in range(len(arm_names)): + heading = f"Arm {arm_names[i]}
" + x_lab = ( + "{name}: {estimate} {ci}
".format( + name=x_axis_var.metric_name if x_axis_label is None else x_axis_label, + estimate=( + round(x[i], DECIMALS) if isinstance(x[i], numbers.Number) else x[i] + ), + ci="" if x_se is None else _format_CI(x[i], x_se[i]), + ) + if x_axis_var is not None + else "" + ) + y_lab = "{name}: {estimate} {ci}
".format( + name=y_axis_var.metric_name if y_axis_label is None else y_axis_label, + estimate=( + round(y[i], DECIMALS) if isinstance(y[i], numbers.Number) else y[i] + ), + ci="" if y_se is None else _format_CI(y[i], y_se[i]), + ) + + parameterization = _format_dict(arms[i].parameters, "Parameterization") + + labels.append( + "{arm_name}
{xlab}{ylab}{param_blob}".format( + arm_name=heading, + xlab=x_lab, + ylab=y_lab, + param_blob=parameterization, + ) + ) + i += 1 + + trace = go.Scatter( + x=x, + y=y, + marker={"color": rgba(COLORS.STEELBLUE.value)}, + mode="markers", + name="In-sample", + text=labels, + hoverinfo="text", + ) + + if show_CI: + if x_se is not None: + trace.update( + error_x={ + "type": "data", + "array": np.multiply(x_se, Z), + "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY), + } + ) + if y_se is not None: + trace.update( + error_y={ + "type": "data", + "array": np.multiply(y_se, Z), + "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY), + } + ) + + trace.update(visible=True) + trace.update(showlegend=True) + return trace diff --git a/ax/analysis/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/helpers/tests/test_cross_validation_helpers.py new file mode 100644 index 00000000000..0f091ccdaeb --- /dev/null +++ b/ax/analysis/helpers/tests/test_cross_validation_helpers.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import plotly.graph_objects as go + +from ax.analysis.helpers.constants import Z + +from ax.analysis.helpers.cross_validation_helpers import ( + get_cv_plot_data, + get_min_max_with_errors, + obs_vs_pred_dropdown_plot, +) + +from ax.modelbridge.cross_validation import cross_validate +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import fast_botorch_optimize + + +class TestCrossValidationHelpers(TestCase): + @fast_botorch_optimize + def setUp(self) -> None: + exp = get_branin_experiment(with_batch=True) + exp.trials[0].run() + self.model = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + self.exp_status_quo = get_branin_experiment( + with_batch=True, with_status_quo=True + ) + self.exp_status_quo.trials[0].run() + self.model_status_quo = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=exp, + data=exp.fetch_data(), + ) + + def test_get_min_max_with_errors(self) -> None: + # Test with sample data + x = [1.0, 2.0, 3.0] + y = [4.0, 5.0, 6.0] + sd_x = [0.1, 0.2, 0.3] + sd_y = [0.1, 0.2, 0.3] + min_, max_ = get_min_max_with_errors(x, y, sd_x, sd_y) + + expected_min = 1.0 - 0.1 * Z + expected_max = 6.0 + 0.3 * Z + # Check that the returned values are correct + print(f"min: {min_} {expected_min=}") + print(f"max: {max_} {expected_max=}") + self.assertAlmostEqual(min_, expected_min, delta=1e-4) + self.assertAlmostEqual(max_, expected_max, delta=1e-4) + + def test_obs_vs_pred_dropdown_plot(self) -> None: + cv_results = cross_validate(self.model) + + label_dict = {"branin": "BrAnIn"} + + data = get_cv_plot_data(cv_results, label_dict=label_dict) + fig = obs_vs_pred_dropdown_plot( + data=data, + ) + + self.assertIsInstance(fig, go.Figure)