diff --git a/ax/analysis/helpers/color_helpers.py b/ax/analysis/helpers/color_helpers.py
new file mode 100644
index 00000000000..575b641f014
--- /dev/null
+++ b/ax/analysis/helpers/color_helpers.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from numbers import Real
+from typing import List, Tuple
+
+# type aliases
+TRGB = Tuple[Real, ...]
+
+
+def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str:
+ """Convert RGB tuple to an RGBA string."""
+ return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha)
+
+
+def plotly_color_scale(
+ list_of_rgb_tuples: List[TRGB],
+ reverse: bool = False,
+ alpha: float = 1,
+) -> List[Tuple[float, str]]:
+ """Convert list of RGB tuples to list of tuples, where each tuple is
+ break in [0, 1] and stringified RGBA color.
+ """
+ if reverse:
+ list_of_rgb_tuples = list_of_rgb_tuples[::-1]
+ return [
+ (round(i / (len(list_of_rgb_tuples) - 1), 3), rgba(rgb))
+ for i, rgb in enumerate(list_of_rgb_tuples)
+ ]
diff --git a/ax/analysis/helpers/constants.py b/ax/analysis/helpers/constants.py
new file mode 100644
index 00000000000..c5af778ebac
--- /dev/null
+++ b/ax/analysis/helpers/constants.py
@@ -0,0 +1,97 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import enum
+
+# Constants used for numerous plots
+CI_OPACITY = 0.4
+DECIMALS = 3
+Z = 1.96
+
+
+# color constants used for plotting
+class COLORS(enum.Enum):
+ STEELBLUE = (128, 177, 211)
+ CORAL = (251, 128, 114)
+ TEAL = (141, 211, 199)
+ PINK = (188, 128, 189)
+ LIGHT_PURPLE = (190, 186, 218)
+ ORANGE = (253, 180, 98)
+
+
+# colors to be used for plotting discrete series
+# pyre-fixme[5]: Global expression must be annotated.
+DISCRETE_COLOR_SCALE = [
+ COLORS.STEELBLUE.value,
+ COLORS.CORAL.value,
+ COLORS.PINK.value,
+ COLORS.LIGHT_PURPLE.value,
+ COLORS.ORANGE.value,
+ COLORS.TEAL.value,
+]
+
+# 11-class PiYG from ColorBrewer (for contour plots)
+GREEN_PINK_SCALE = [
+ (142, 1, 82),
+ (197, 27, 125),
+ (222, 119, 174),
+ (241, 182, 218),
+ (253, 224, 239),
+ (247, 247, 247),
+ (230, 245, 208),
+ (184, 225, 134),
+ (127, 188, 65),
+ (77, 146, 33),
+ (39, 100, 25),
+]
+GREEN_SCALE = [
+ (247, 252, 253),
+ (229, 245, 249),
+ (204, 236, 230),
+ (153, 216, 201),
+ (102, 194, 164),
+ (65, 174, 118),
+ (35, 139, 69),
+ (0, 109, 44),
+ (0, 68, 27),
+]
+BLUE_SCALE = [
+ (255, 247, 251),
+ (236, 231, 242),
+ (208, 209, 230),
+ (166, 189, 219),
+ (116, 169, 207),
+ (54, 144, 192),
+ (5, 112, 176),
+ (3, 78, 123),
+]
+# 24 Class Mixed Color Palette
+# Source: https://graphicdesign.stackexchange.com/a/3815
+MIXED_SCALE = [
+ (2, 63, 165),
+ (125, 135, 185),
+ (190, 193, 212),
+ (214, 188, 192),
+ (187, 119, 132),
+ (142, 6, 59),
+ (74, 111, 227),
+ (133, 149, 225),
+ (181, 187, 227),
+ (230, 175, 185),
+ (224, 123, 145),
+ (211, 63, 106),
+ (17, 198, 56),
+ (141, 213, 147),
+ (198, 222, 199),
+ (234, 211, 198),
+ (240, 185, 141),
+ (239, 151, 8),
+ (15, 207, 192),
+ (156, 222, 214),
+ (213, 234, 231),
+ (243, 225, 235),
+ (246, 196, 225),
+ (247, 156, 212),
+]
diff --git a/ax/analysis/helpers/cross_validation_helpers.py b/ax/analysis/helpers/cross_validation_helpers.py
new file mode 100644
index 00000000000..b6fac5b73fa
--- /dev/null
+++ b/ax/analysis/helpers/cross_validation_helpers.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import plotly.graph_objs as go
+
+from ax.analysis.helpers.constants import Z
+
+from ax.analysis.helpers.layout_helpers import layout_format, updatemenus_format
+from ax.analysis.helpers.scatter_helpers import (
+ _error_scatter_data,
+ _error_scatter_trace,
+ PlotData,
+ PlotInSampleArm,
+ PlotMetric,
+)
+
+from ax.exceptions.core import UnsupportedPlotError
+
+from ax.modelbridge.cross_validation import CVResult
+
+
+# Helper functions for plotting model fits
+def get_min_max_with_errors(
+ x: List[float], y: List[float], sd_x: List[float], sd_y: List[float]
+) -> Tuple[float, float]:
+ """Get min and max of a bivariate dataset (across variables).
+
+ Args:
+ x: point estimate of x variable.
+ y: point estimate of y variable.
+ sd_x: standard deviation of x variable.
+ sd_y: standard deviation of y variable.
+
+ Returns:
+ min_: minimum of points, including uncertainty.
+ max_: maximum of points, including uncertainty.
+
+ """
+ min_ = min(
+ min(np.array(x) - np.multiply(sd_x, Z)), min(np.array(y) - np.multiply(sd_y, Z))
+ )
+ max_ = max(
+ max(np.array(x) + np.multiply(sd_x, Z)), max(np.array(y) + np.multiply(sd_y, Z))
+ )
+ return min_, max_
+
+
+def diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str, Any]:
+ """Diagonal line trace from (min_, min_) to (max_, max_).
+
+ Args:
+ min_: minimum to be used for starting point of line.
+ max_: maximum to be used for ending point of line.
+ visible: if True, trace is set to visible.
+
+ """
+ return go.Scatter(
+ x=[min_, max_],
+ y=[min_, max_],
+ line=dict(color="black", width=2, dash="dot"), # noqa: C408
+ mode="lines",
+ hoverinfo="none",
+ visible=visible,
+ showlegend=False,
+ )
+
+
+def default_value_se_raw(se_raw: Optional[List[float]], out_length: int) -> List[float]:
+ """
+ Takes a list of standard errors and maps edge cases to default list
+ of floats.
+
+ """
+ new_se_raw = (
+ [0.0 if np.isnan(se) else se for se in se_raw]
+ if se_raw is not None
+ else [0.0] * out_length
+ )
+ return new_se_raw
+
+
+def obs_vs_pred_dropdown_plot(
+ data: PlotData,
+ rel: bool,
+ show_context: bool = False,
+ xlabel: str = "Actual Outcome",
+ ylabel: str = "Predicted Outcome",
+ autoset_axis_limits: bool = True,
+) -> go.Figure:
+ """Plot a dropdown plot of observed vs. predicted values from a model.
+ NOTE: If relative, cannot show additional context
+ on arm details on hover
+
+ Args:
+ data: a name tuple storing observed and predicted data
+ from a model.
+ rel: if True, plot metrics relative to the status quo.
+ show_context: Show arm detail context on hover.
+ xlabel: Label for x-axis.
+ ylabel: Label for y-axis.
+ autoset_axis_limits: Automatically try to set the limit for each axis to focus
+ on the region of interest.
+ """
+ traces = []
+ metric_dropdown = []
+ layout_axis_range = []
+ status_quo_arm = None
+ if rel and data.status_quo_name is not None:
+ if show_context:
+ raise UnsupportedPlotError(
+ "This plot does not support both context and relativization at "
+ "the same time."
+ )
+ status_quo_arm = data.in_sample[data.status_quo_name]
+
+ for i, metric in enumerate(data.metrics):
+ y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
+ list(data.in_sample.values()),
+ y_axis_var=PlotMetric(metric_name=metric, pred=True, rel=rel),
+ x_axis_var=PlotMetric(metric_name=metric, pred=False, rel=rel),
+ status_quo_arm=status_quo_arm,
+ )
+ se_raw = default_value_se_raw(se_raw=se_raw, out_length=len(y_raw))
+
+ min_, max_ = get_min_max_with_errors(x=y_raw, y=y_hat, sd_x=se_raw, sd_y=se_hat)
+ if autoset_axis_limits:
+ y_raw_np = np.array(y_raw)
+ # TODO: replace interpolation->method once it becomes standard.
+ q1 = np.nanpercentile(y_raw_np, q=25, interpolation="lower").min()
+ q3 = np.nanpercentile(y_raw_np, q=75, interpolation="higher").max()
+ y_lower = q1 - 1.5 * (q3 - q1)
+ y_upper = q3 + 1.5 * (q3 - q1)
+ y_raw_np = y_raw_np.clip(y_lower, y_upper).tolist()
+ min_robust, max_robust = get_min_max_with_errors(
+ x=y_raw_np, y=y_hat, sd_x=se_raw, sd_y=se_hat
+ )
+ y_padding = 0.05 * (max_robust - min_robust)
+ # Use the min/max of the limits
+ layout_axis_range.append(
+ [max(min_robust, min_) - y_padding, min(max_robust, max_) + y_padding]
+ )
+ traces.append(
+ diagonal_trace(
+ min(min_robust, min_) - y_padding,
+ max(max_robust, max_) + y_padding,
+ visible=(i == 0),
+ )
+ )
+ else:
+ layout_axis_range.append(None)
+ traces.append(diagonal_trace(min_, max_, visible=(i == 0)))
+
+ traces.append(
+ _error_scatter_trace(
+ arms=list(data.in_sample.values()),
+ hoverinfo="text",
+ show_arm_details_on_hover=True,
+ show_CI=True,
+ show_context=show_context,
+ status_quo_arm=status_quo_arm,
+ visible=(i == 0),
+ x_axis_label=xlabel,
+ x_axis_var=PlotMetric(metric_name=metric, pred=False, rel=rel),
+ y_axis_label=ylabel,
+ y_axis_var=PlotMetric(metric_name=metric, pred=True, rel=rel),
+ )
+ )
+
+ # only the first two traces are visible (corresponding to first outcome
+ # in dropdown)
+ is_visible = [False] * (len(data.metrics) * 2)
+ is_visible[2 * i] = True
+ is_visible[2 * i + 1] = True
+
+ # on dropdown change, restyle
+ metric_dropdown.append(
+ {
+ "args": [
+ {"visible": is_visible},
+ {
+ "xaxis.range": layout_axis_range[-1],
+ "yaxis.range": layout_axis_range[-1],
+ },
+ ],
+ "label": metric,
+ "method": "update",
+ }
+ )
+
+ updatemenus = updatemenus_format(metric_dropdown=metric_dropdown)
+ layout = layout_format(
+ layout_axis_range_value=layout_axis_range[0],
+ xlabel=xlabel,
+ ylabel=ylabel,
+ updatemenus=updatemenus,
+ )
+
+ return go.Figure(data=traces, layout=layout)
+
+
+def get_cv_plot_data(
+ cv_results: List[CVResult], label_dict: Optional[Dict[str, str]]
+) -> PlotData:
+ """Construct PlotData from cv_results, mapping observed to y and se,
+ and predicted to y_hat and se_hat.
+
+ Args:
+ cv_results: A CVResult for each observation in the training data.
+ label_dict: optional map from real metric names to shortened names
+
+ Returns:
+ PlotData with the following fields:
+ metrics: List[str]
+ in_sample: Dict[str, PlotInSampleArm]
+ PlotInSample arm have the fields
+ {
+ "name"
+ "y"
+ "se"
+ "parameters"
+ "y_hat"
+ "se_hat"
+ "context_stratum": (Always NONE)
+ }
+ out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]]
+ (NOTE: Always None)
+ status_quo_name: Optional[str]
+ (NOTE: Always None)
+
+ """
+ if len(cv_results) == 0:
+ return PlotData(
+ metrics=[], in_sample={}, out_of_sample=None, status_quo_name=None
+ )
+
+ if label_dict is None:
+ label_dict = {}
+ # Apply label_dict to cv_results
+ cv_results = deepcopy(cv_results) # Copy and edit in-place
+ for cv_i in cv_results:
+ cv_i.observed.data.metric_names = [
+ label_dict.get(m, m) for m in cv_i.observed.data.metric_names
+ ]
+ cv_i.predicted.metric_names = [
+ label_dict.get(m, m) for m in cv_i.predicted.metric_names
+ ]
+
+ # arm_name -> Arm data
+ insample_data: Dict[str, PlotInSampleArm] = {}
+
+ # Get the union of all metric_names seen in predictions
+ metric_names = list(
+ set().union(*(cv_result.predicted.metric_names for cv_result in cv_results))
+ )
+
+ for rid, cv_result in enumerate(cv_results):
+ arm_name = cv_result.observed.arm_name
+ arm_data = {
+ "name": cv_result.observed.arm_name,
+ "y": {},
+ "se": {},
+ "parameters": cv_result.observed.features.parameters,
+ "y_hat": {},
+ "se_hat": {},
+ "context_stratum": None,
+ }
+ for i, mname in enumerate(cv_result.observed.data.metric_names):
+ # pyre-fixme[16]: Optional type has no attribute `__setitem__`.
+ arm_data["y"][mname] = cv_result.observed.data.means[i]
+ # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any,
+ # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]`
+ # has no attribute `__setitem__`.
+ arm_data["se"][mname] = np.sqrt(cv_result.observed.data.covariance[i][i])
+ for i, mname in enumerate(cv_result.predicted.metric_names):
+ # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any,
+ # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]`
+ # has no attribute `__setitem__`.
+ arm_data["y_hat"][mname] = cv_result.predicted.means[i]
+ # pyre-fixme[16]: Item `None` of `Union[None, Dict[typing.Any,
+ # typing.Any], Dict[str, typing.Union[None, bool, float, int, str]], str]`
+ # has no attribute `__setitem__`.
+ arm_data["se_hat"][mname] = np.sqrt(cv_result.predicted.covariance[i][i])
+
+ # Expected `str` for 2nd anonymous parameter to call `dict.__setitem__` but got
+ # `Optional[str]`.
+ # pyre-fixme[6]:
+ insample_data[f"{arm_name}_{rid}"] = PlotInSampleArm(**arm_data)
+ return PlotData(
+ metrics=metric_names,
+ in_sample=insample_data,
+ out_of_sample=None,
+ status_quo_name=None,
+ )
diff --git a/ax/analysis/helpers/layout_helpers.py b/ax/analysis/helpers/layout_helpers.py
new file mode 100644
index 00000000000..ccde6993377
--- /dev/null
+++ b/ax/analysis/helpers/layout_helpers.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, List, Tuple, Type
+
+import plotly.graph_objs as go
+
+
+def updatemenus_format(metric_dropdown: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ return [
+ {
+ "x": 0,
+ "y": 1.125,
+ "yanchor": "top",
+ "xanchor": "left",
+ "buttons": metric_dropdown,
+ },
+ {
+ "buttons": [
+ {
+ "args": [
+ {
+ "error_x.width": 4,
+ "error_x.thickness": 2,
+ "error_y.width": 4,
+ "error_y.thickness": 2,
+ }
+ ],
+ "label": "Yes",
+ "method": "restyle",
+ },
+ {
+ "args": [
+ {
+ "error_x.width": 0,
+ "error_x.thickness": 0,
+ "error_y.width": 0,
+ "error_y.thickness": 0,
+ }
+ ],
+ "label": "No",
+ "method": "restyle",
+ },
+ ],
+ "x": 1.125,
+ "xanchor": "left",
+ "y": 0.8,
+ "yanchor": "middle",
+ },
+ ]
+
+
+def layout_format(
+ layout_axis_range_value: Tuple[float, float],
+ xlabel: str,
+ ylabel: str,
+ updatemenus: List[Dict[str, Any]],
+) -> Type[go.graph_objs.Figure]:
+ layout = go.Layout(
+ annotations=[
+ {
+ "showarrow": False,
+ "text": "Show CI",
+ "x": 1.125,
+ "xanchor": "left",
+ "xref": "paper",
+ "y": 0.9,
+ "yanchor": "middle",
+ "yref": "paper",
+ }
+ ],
+ xaxis={
+ "range": layout_axis_range_value,
+ "title": xlabel,
+ "zeroline": False,
+ "mirror": True,
+ "linecolor": "black",
+ "linewidth": 0.5,
+ },
+ yaxis={
+ "range": layout_axis_range_value,
+ "title": ylabel,
+ "zeroline": False,
+ "mirror": True,
+ "linecolor": "black",
+ "linewidth": 0.5,
+ },
+ showlegend=False,
+ hovermode="closest",
+ updatemenus=updatemenus,
+ width=530,
+ height=500,
+ )
+ return layout
diff --git a/ax/analysis/helpers/plot_helpers.py b/ax/analysis/helpers/plot_helpers.py
new file mode 100644
index 00000000000..60aab0373c2
--- /dev/null
+++ b/ax/analysis/helpers/plot_helpers.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from logging import Logger
+from typing import Dict, List, Optional, Tuple, Union
+
+from ax.analysis.helpers.constants import DECIMALS, Z
+
+from ax.core.generator_run import GeneratorRun
+
+from ax.core.types import TParameterization
+from ax.utils.common.logger import get_logger
+
+
+logger: Logger = get_logger(__name__)
+
+# Typing alias
+RawData = List[Dict[str, Union[str, float]]]
+
+TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]]
+
+
+def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str:
+ """Format a dictionary for labels.
+
+ Args:
+ param_dict: Dictionary to be formatted
+ name: String name of the thing being formatted.
+
+ Returns: stringified blob.
+ """
+ if len(param_dict) >= 10:
+ blob = "{} has too many items to render on hover ({}).".format(
+ name, len(param_dict)
+ )
+ else:
+ blob = "
{}:
{}".format(
+ name, "
".join("{}: {}".format(n, v) for n, v in param_dict.items())
+ )
+ return blob
+
+
+def _format_CI(estimate: float, sd: float, relative: bool, zval: float = Z) -> str:
+ """Format confidence intervals given estimate and standard deviation.
+
+ Args:
+ estimate: point estimate.
+ sd: standard deviation of point estimate.
+ relative: if True, '%' is appended.
+ zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs)
+
+ Returns: formatted confidence interval.
+ """
+ return "[{lb:.{digits}f}{perc}, {ub:.{digits}f}{perc}]".format(
+ lb=estimate - zval * sd,
+ ub=estimate + zval * sd,
+ digits=DECIMALS,
+ perc="%" if relative else "",
+ )
+
+
+def arm_name_to_sort_key(arm_name: str) -> Tuple[str, int, int]:
+ """Parses arm name into tuple suitable for reverse sorting by key
+
+ Example:
+ arm_names = ["0_0", "1_10", "1_2", "10_0", "control"]
+ sorted(arm_names, key=arm_name_to_sort_key, reverse=True)
+ ["control", "0_0", "1_2", "1_10", "10_0"]
+ """
+ try:
+ trial_index, arm_index = arm_name.split("_")
+ return ("", -int(trial_index), -int(arm_index))
+ except (ValueError, IndexError):
+ return (arm_name, 0, 0)
diff --git a/ax/analysis/helpers/scatter_helpers.py b/ax/analysis/helpers/scatter_helpers.py
new file mode 100644
index 00000000000..62c9431d4fa
--- /dev/null
+++ b/ax/analysis/helpers/scatter_helpers.py
@@ -0,0 +1,325 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numbers
+
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
+
+import numpy as np
+import plotly.graph_objs as go
+from ax.analysis.helpers.color_helpers import rgba
+
+from ax.analysis.helpers.constants import BLUE_SCALE, CI_OPACITY, COLORS, DECIMALS, Z
+
+from ax.analysis.helpers.plot_helpers import (
+ _format_CI,
+ _format_dict,
+ arm_name_to_sort_key,
+)
+
+from ax.core.types import TParameterization
+
+from ax.exceptions.core import UnsupportedPlotError
+
+from ax.utils.stats.statstools import relativize
+
+
+# Structs for plot data
+class PlotMetric(NamedTuple):
+ """Struct for metric"""
+
+ metric_name: str
+ pred: bool
+ rel: bool
+
+
+class PlotInSampleArm(NamedTuple):
+ """Struct for in-sample arms (both observed and predicted data)"""
+
+ name: str
+ parameters: TParameterization
+ y: Dict[str, float]
+ y_hat: Dict[str, float]
+ se: Dict[str, float]
+ se_hat: Dict[str, float]
+ context_stratum: Optional[Dict[str, Union[str, float]]]
+
+
+class PlotOutOfSampleArm(NamedTuple):
+ """Struct for out-of-sample arms (only predicted data)"""
+
+ name: str
+ parameters: TParameterization
+ y_hat: Dict[str, float]
+ se_hat: Dict[str, float]
+ context_stratum: Optional[Dict[str, Union[str, float]]]
+
+
+class PlotData(NamedTuple):
+ """Struct for plot data, including both in-sample and out-of-sample arms"""
+
+ metrics: List[str]
+ in_sample: Dict[str, PlotInSampleArm]
+ out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]]
+ status_quo_name: Optional[str]
+
+
+def _error_scatter_data(
+ arms: Iterable[Union[PlotInSampleArm, PlotOutOfSampleArm]],
+ y_axis_var: PlotMetric,
+ x_axis_var: Optional[PlotMetric] = None,
+ status_quo_arm: Optional[PlotInSampleArm] = None,
+) -> Tuple[List[float], Optional[List[float]], List[float], List[float]]:
+ y_metric_key = "y_hat" if y_axis_var.pred else "y"
+ y_sd_key = "se_hat" if y_axis_var.pred else "se"
+
+ arm_names = [a.name for a in arms]
+ y = [getattr(a, y_metric_key).get(y_axis_var.metric_name, np.nan) for a in arms]
+ y_se = [getattr(a, y_sd_key).get(y_axis_var.metric_name, np.nan) for a in arms]
+
+ # Delta method if relative to status quo arm
+ if y_axis_var.rel:
+ if status_quo_arm is None:
+ raise UnsupportedPlotError(
+ "`status_quo_arm` cannot be None for relative effects."
+ )
+ y_rel, y_se_rel = relativize(
+ means_t=y,
+ sems_t=y_se,
+ mean_c=getattr(status_quo_arm, y_metric_key).get(y_axis_var.metric_name),
+ sem_c=getattr(status_quo_arm, y_sd_key).get(y_axis_var.metric_name),
+ as_percent=True,
+ )
+ y = y_rel.tolist()
+ y_se = y_se_rel.tolist()
+
+ # x can be metric for a metric or arm names
+ if x_axis_var is None:
+ x = arm_names
+ x_se = None
+ else:
+ x_metric_key = "y_hat" if x_axis_var.pred else "y"
+ x_sd_key = "se_hat" if x_axis_var.pred else "se"
+ x = [getattr(a, x_metric_key).get(x_axis_var.metric_name, np.nan) for a in arms]
+ x_se = [getattr(a, x_sd_key).get(x_axis_var.metric_name, np.nan) for a in arms]
+
+ if x_axis_var.rel:
+ # Delta method if relative to status quo arm
+ x_rel, x_se_rel = relativize(
+ means_t=x,
+ sems_t=x_se,
+ mean_c=getattr(status_quo_arm, x_metric_key).get(
+ x_axis_var.metric_name
+ ),
+ sem_c=getattr(status_quo_arm, x_sd_key).get(x_axis_var.metric_name),
+ as_percent=True,
+ )
+ x = x_rel.tolist()
+ x_se = x_se_rel.tolist()
+ return x, x_se, y, y_se
+
+
+def _error_scatter_trace(
+ arms: Sequence[Union[PlotInSampleArm, PlotOutOfSampleArm]],
+ y_axis_var: PlotMetric,
+ x_axis_var: Optional[PlotMetric] = None,
+ y_axis_label: Optional[str] = None,
+ x_axis_label: Optional[str] = None,
+ status_quo_arm: Optional[PlotInSampleArm] = None,
+ show_CI: bool = True,
+ name: str = "In-sample",
+ color: Tuple[int] = COLORS.STEELBLUE.value,
+ visible: bool = True,
+ legendgroup: Optional[str] = None,
+ showlegend: bool = True,
+ hoverinfo: str = "text",
+ show_arm_details_on_hover: bool = True,
+ show_context: bool = False,
+ arm_noun: str = "arm",
+ color_parameter: Optional[str] = None,
+ color_metric: Optional[str] = None,
+) -> Dict[str, Any]:
+ """Plot scatterplot with error bars.
+
+ Args:
+ arms (List[Union[PlotInSampleArm, PlotOutOfSampleArm]]):
+ a list of in-sample or out-of-sample arms.
+ In-sample arms have observed data, while out-of-sample arms
+ just have predicted data. As a result,
+ when passing out-of-sample arms, pred must be True.
+ y_axis_var: name of metric for y-axis, along with whether
+ it is observed or predicted.
+ x_axis_var: name of metric for x-axis,
+ along with whether it is observed or predicted. If None, arm names
+ are automatically used.
+ y_axis_label: custom label to use for y axis.
+ If None, use metric name from `y_axis_var`.
+ x_axis_label: custom label to use for x axis.
+ If None, use metric name from `x_axis_var` if that is not None.
+ status_quo_arm: the status quo
+ arm. Necessary for relative metrics.
+ show_CI: if True, plot confidence intervals.
+ name: name of trace. Default is "In-sample".
+ color: color as rgb tuple. Default is
+ (128, 177, 211), which corresponds to COLORS.STEELBLUE.
+ visible: if True, trace is visible (default).
+ legendgroup: group for legends.
+ showlegend: if True, legend if rendered.
+ hoverinfo: information to show on hover. Default is
+ custom text.
+ show_arm_details_on_hover: if True, display
+ parameterizations of arms on hover. Default is True.
+ show_context: if True and show_arm_details_on_hover,
+ context will be included in the hover.
+ arm_noun: noun to use instead of "arm" (e.g. group)
+ color_parameter: color points according to the specified parameter,
+ cannot be used together with color_metric.
+ color_metric: color points according to the specified metric,
+ cannot be used together with color_parameter.
+ """
+ if color_metric and color_parameter:
+ raise RuntimeError(
+ "color_metric and color_parameter cannot be used at the same time!"
+ )
+
+ if (color_metric or color_parameter) and not all(
+ isinstance(arm, PlotInSampleArm) for arm in arms
+ ):
+ raise RuntimeError("Color coding currently only works with in-sample arms!")
+
+ # Opportunistically sort if arm names are in {trial}_{arm} format
+ arms = sorted(arms, key=lambda a: arm_name_to_sort_key(a.name), reverse=True)
+
+ x, x_se, y, y_se = _error_scatter_data(
+ arms=arms,
+ y_axis_var=y_axis_var,
+ x_axis_var=x_axis_var,
+ status_quo_arm=status_quo_arm,
+ )
+ labels = []
+ colors = []
+
+ arm_names = [a.name for a in arms]
+
+ # No relativization if no x variable.
+ rel_x = x_axis_var.rel if x_axis_var else False
+ rel_y = y_axis_var.rel
+
+ for i in range(len(arm_names)):
+ heading = f"{arm_noun.title()} {arm_names[i]}
"
+ x_lab = (
+ "{name}: {estimate}{perc} {ci}
".format(
+ name=x_axis_var.metric_name if x_axis_label is None else x_axis_label,
+ estimate=(
+ round(x[i], DECIMALS) if isinstance(x[i], numbers.Number) else x[i]
+ ),
+ ci="" if x_se is None else _format_CI(x[i], x_se[i], rel_x),
+ perc="%" if rel_x else "",
+ )
+ if x_axis_var is not None
+ else ""
+ )
+ y_lab = "{name}: {estimate}{perc} {ci}
".format(
+ name=y_axis_var.metric_name if y_axis_label is None else y_axis_label,
+ estimate=(
+ round(y[i], DECIMALS) if isinstance(y[i], numbers.Number) else y[i]
+ ),
+ ci="" if y_se is None else _format_CI(y[i], y_se[i], rel_y),
+ perc="%" if rel_y else "",
+ )
+
+ parameterization = (
+ _format_dict(arms[i].parameters, "Parameterization")
+ if show_arm_details_on_hover
+ else ""
+ )
+
+ if color_parameter:
+ colors.append(arms[i].parameters[color_parameter])
+ elif color_metric:
+ arm = arms[i]
+ if isinstance(arm, PlotInSampleArm):
+ colors.append(arm.y[color_metric])
+
+ arm = arms[i]
+ context = (
+ # Expected `Dict[str, Optional[Union[bool, float, str]]]` for 1st anonymous
+ # parameter to call `ax.analysis.helpers.plot_helpers._format_dict` but got
+ # `Optional[Dict[str, Union[float, str]]]`.
+ # pyre-fixme[6]:
+ _format_dict(arms[i].context_stratum, "Context")
+ if show_arm_details_on_hover
+ and show_context # noqa W503
+ and arms[i].context_stratum # noqa W503
+ else ""
+ )
+
+ labels.append(
+ "{arm_name}
{xlab}{ylab}{param_blob}{context}".format(
+ arm_name=heading,
+ xlab=x_lab,
+ ylab=y_lab,
+ param_blob=parameterization,
+ context=context,
+ )
+ )
+ i += 1
+
+ if color_metric or color_parameter:
+ rgba_blue_scale = [rgba(c) for c in BLUE_SCALE]
+ marker = {
+ "color": colors,
+ "colorscale": rgba_blue_scale,
+ "colorbar": {"title": color_metric or color_parameter},
+ "showscale": True,
+ }
+ else:
+ marker = {"color": rgba(color)}
+
+ trace = go.Scatter(
+ x=x,
+ y=y,
+ marker=marker,
+ mode="markers",
+ name=name,
+ text=labels,
+ hoverinfo=hoverinfo,
+ )
+
+ if show_CI:
+ if x_se is not None:
+ trace.update(
+ error_x={
+ "type": "data",
+ "array": np.multiply(x_se, Z),
+ "color": rgba(color, CI_OPACITY),
+ }
+ )
+ if y_se is not None:
+ trace.update(
+ error_y={
+ "type": "data",
+ "array": np.multiply(y_se, Z),
+ "color": rgba(color, CI_OPACITY),
+ }
+ )
+ if visible is not None:
+ trace.update(visible=visible)
+ if legendgroup is not None:
+ trace.update(legendgroup=legendgroup)
+ if showlegend is not None:
+ trace.update(showlegend=showlegend)
+ return trace
diff --git a/ax/analysis/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/helpers/tests/test_cross_validation_helpers.py
new file mode 100644
index 00000000000..bf5dea9f445
--- /dev/null
+++ b/ax/analysis/helpers/tests/test_cross_validation_helpers.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import plotly.graph_objects as go
+
+from ax.analysis.helpers.constants import Z
+
+from ax.analysis.helpers.cross_validation_helpers import (
+ get_cv_plot_data,
+ get_min_max_with_errors,
+ obs_vs_pred_dropdown_plot,
+)
+
+from ax.exceptions.core import UnsupportedPlotError
+from ax.modelbridge.cross_validation import cross_validate
+from ax.modelbridge.registry import Models
+from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_branin_experiment
+from ax.utils.testing.mock import fast_botorch_optimize
+
+
+class TestCrossValidationHelpers(TestCase):
+ @fast_botorch_optimize
+ def setUp(self) -> None:
+ exp = get_branin_experiment(with_batch=True)
+ exp.trials[0].run()
+ self.model = Models.BOTORCH_MODULAR(
+ # Model bridge kwargs
+ experiment=exp,
+ data=exp.fetch_data(),
+ )
+
+ self.exp_status_quo = get_branin_experiment(
+ with_batch=True, with_status_quo=True
+ )
+ self.exp_status_quo.trials[0].run()
+ self.model_status_quo = Models.BOTORCH_MODULAR(
+ # Model bridge kwargs
+ experiment=exp,
+ data=exp.fetch_data(),
+ )
+
+ def test_get_min_max_with_errors(self) -> None:
+ # Test with sample data
+ x = [1.0, 2.0, 3.0]
+ y = [4.0, 5.0, 6.0]
+ sd_x = [0.1, 0.2, 0.3]
+ sd_y = [0.1, 0.2, 0.3]
+ min_, max_ = get_min_max_with_errors(x, y, sd_x, sd_y)
+
+ expected_min = 1.0 - 0.1 * Z
+ expected_max = 6.0 + 0.3 * Z
+ # Check that the returned values are correct
+ print(f"min: {min_} {expected_min=}")
+ print(f"max: {max_} {expected_max=}")
+ self.assertAlmostEqual(min_, expected_min, delta=1e-4)
+ self.assertAlmostEqual(max_, expected_max, delta=1e-4)
+
+ def test_obs_vs_pred_dropdown_plot(self) -> None:
+ for autoset_axis_limits in [False, True]:
+ for show_context in [False, True]:
+ # Assert that each type of plot can be constructed successfully
+ print(f"{autoset_axis_limits=} {show_context=}")
+
+ cv_results = cross_validate(self.model)
+
+ label_dict = {"branin": "BrAnIn"}
+
+ data = get_cv_plot_data(cv_results, label_dict=label_dict)
+ fig = obs_vs_pred_dropdown_plot(
+ data=data,
+ rel=False,
+ show_context=show_context,
+ autoset_axis_limits=autoset_axis_limits,
+ )
+
+ self.assertIsInstance(fig, go.Figure)
+
+ def test_obs_vs_pred_dropdown_plot_relative_effects(self) -> None:
+ for autoset_axis_limits in [False, True]:
+ for show_context in [False, True]:
+ # Assert that each type of plot can be constructed successfully
+ print(f"{autoset_axis_limits=} {show_context=}")
+
+ cv_results = cross_validate(self.model_status_quo)
+
+ label_dict = {"branin": "BrAnIn"}
+
+ data = get_cv_plot_data(cv_results, label_dict=label_dict)
+ in_sample = data.in_sample
+ data_value = next(iter(in_sample.values()))
+ in_sample[self.exp_status_quo.status_quo.name] = data_value
+ data = data._replace(
+ in_sample=in_sample,
+ status_quo_name=self.exp_status_quo.status_quo.name,
+ )
+
+ if show_context:
+ with self.assertRaisesRegex(
+ UnsupportedPlotError,
+ "This plot does not support both context"
+ " and relativization at the same time",
+ ):
+ _ = obs_vs_pred_dropdown_plot(
+ data=data,
+ rel=True,
+ show_context=show_context,
+ autoset_axis_limits=autoset_axis_limits,
+ )
+ continue
+ fig = obs_vs_pred_dropdown_plot(
+ data=data,
+ rel=True,
+ show_context=show_context,
+ autoset_axis_limits=autoset_axis_limits,
+ )
+
+ self.assertIsInstance(fig, go.Figure)
diff --git a/ax/analysis/helpers/tests/test_scatter_helpers.py b/ax/analysis/helpers/tests/test_scatter_helpers.py
new file mode 100644
index 00000000000..d2317ee7644
--- /dev/null
+++ b/ax/analysis/helpers/tests/test_scatter_helpers.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ax.modelbridge.registry import Models
+from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_branin_experiment
+from ax.utils.testing.mock import fast_botorch_optimize
+
+
+class TestScatterHelpers(TestCase):
+ @fast_botorch_optimize
+ def setUp(self) -> None:
+ exp = get_branin_experiment(with_batch=True)
+ exp.trials[0].run()
+ self.model = Models.BOTORCH_MODULAR(
+ # Model bridge kwargs
+ experiment=exp,
+ data=exp.fetch_data(),
+ )
+
+ self.exp_status_quo = get_branin_experiment(
+ with_batch=True, with_status_quo=True
+ )
+ self.exp_status_quo.trials[0].run()
+ self.model_status_quo = Models.BOTORCH_MODULAR(
+ # Model bridge kwargs
+ experiment=exp,
+ data=exp.fetch_data(),
+ )
+
+ def test_error_scatter_data(self) -> None:
+ """
+ stub
+ """
+
+ def test_error_scatter_trace(self) -> None:
+ """
+ stub
+ """