diff --git a/ax/analysis/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/helpers/tests/test_cross_validation_helpers.py index 35cff0317a8..c83ecfe139a 100644 --- a/ax/analysis/helpers/tests/test_cross_validation_helpers.py +++ b/ax/analysis/helpers/tests/test_cross_validation_helpers.py @@ -4,8 +4,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import tempfile + import plotly.graph_objects as go +import plotly.io as pio + from ax.analysis.cross_validation_plot import CrossValidationPlot from ax.analysis.helpers.constants import Z @@ -16,6 +20,9 @@ from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import fast_botorch_optimize +from pandas import read_json + +from pandas.testing import assert_frame_equal class TestCrossValidationHelpers(TestCase): @@ -63,3 +70,26 @@ def test_obs_vs_pred_dropdown_plot(self) -> None: fig = cross_validation_plot.get_fig() self.assertIsInstance(fig, go.Figure) + + def test_store_df_to_file(self) -> None: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: + cross_validation_plot = CrossValidationPlot( + experiment=self.exp, model=self.model + ) + cv_df = cross_validation_plot.get_df() + cv_df.to_json(f.name) + + loaded_dataframe = read_json(f.name, dtype={"arm_name": "str"}) + + assert_frame_equal(cv_df, loaded_dataframe, check_dtype=False) + + def test_store_plot_as_dict(self) -> None: + cross_validation_plot = CrossValidationPlot( + experiment=self.exp, model=self.model + ) + cv_fig = cross_validation_plot.get_fig() + + json_obj = pio.to_json(cv_fig, validate=True, remove_uids=False) + + loaded_json_obj = pio.from_json(json_obj, output_type="Figure") + self.assertEqual(cv_fig, loaded_json_obj)