From 3abe0423be74a294e15c933ed048e16e9e1a3f1b Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Wed, 10 Apr 2024 08:03:34 -0700 Subject: [PATCH] Tests which store and load dataframe and figure Summary: This diff adds tests for CrossValidationPlot - Stores a json representation of the dataframe of the plot to a tempfile, then reads it back, and asserts equality - Converts the plot to a json object, converts it back, then checks equality. We'll need to store analysis objects, so we need to check that the dataframe and figures are serializable to json Differential Revision: D55967859 --- .../tests/test_cross_validation_helpers.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ax/analysis/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/helpers/tests/test_cross_validation_helpers.py index 35cff0317a8..c83ecfe139a 100644 --- a/ax/analysis/helpers/tests/test_cross_validation_helpers.py +++ b/ax/analysis/helpers/tests/test_cross_validation_helpers.py @@ -4,8 +4,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import tempfile + import plotly.graph_objects as go +import plotly.io as pio + from ax.analysis.cross_validation_plot import CrossValidationPlot from ax.analysis.helpers.constants import Z @@ -16,6 +20,9 @@ from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import fast_botorch_optimize +from pandas import read_json + +from pandas.testing import assert_frame_equal class TestCrossValidationHelpers(TestCase): @@ -63,3 +70,26 @@ def test_obs_vs_pred_dropdown_plot(self) -> None: fig = cross_validation_plot.get_fig() self.assertIsInstance(fig, go.Figure) + + def test_store_df_to_file(self) -> None: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: + cross_validation_plot = CrossValidationPlot( + experiment=self.exp, model=self.model + ) + cv_df = cross_validation_plot.get_df() + cv_df.to_json(f.name) + + loaded_dataframe = read_json(f.name, dtype={"arm_name": "str"}) + + assert_frame_equal(cv_df, loaded_dataframe, check_dtype=False) + + def test_store_plot_as_dict(self) -> None: + cross_validation_plot = CrossValidationPlot( + experiment=self.exp, model=self.model + ) + cv_fig = cross_validation_plot.get_fig() + + json_obj = pio.to_json(cv_fig, validate=True, remove_uids=False) + + loaded_json_obj = pio.from_json(json_obj, output_type="Figure") + self.assertEqual(cv_fig, loaded_json_obj)