From 89a3d0c9c45f20a32d08b1aeec159b553f860a48 Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Thu, 8 Aug 2024 09:07:00 -0700 Subject: [PATCH 1/4] SQAAnalysisCard refactor Differential Revision: D60258359 --- ax/storage/sqa_store/decoder.py | 36 +---- ax/storage/sqa_store/encoder.py | 48 +----- ax/storage/sqa_store/sqa_classes.py | 89 ++++------- ax/storage/sqa_store/sqa_config.py | 9 +- ax/storage/sqa_store/tests/test_sqa_store.py | 24 --- .../tests/test_sqa_store_analysis.py | 139 ------------------ ax/storage/utils.py | 7 - 7 files changed, 35 insertions(+), 317 deletions(-) delete mode 100644 ax/storage/sqa_store/tests/test_sqa_store_analysis.py diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 3f016797ac7..1d7e0e9e238 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -14,10 +14,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union import pandas as pd -import plotly.io as pio - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -54,7 +50,6 @@ from ax.storage.sqa_store.db import session_scope from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, - SQAAnalysis, SQAArm, SQAData, SQAExperiment, @@ -67,16 +62,10 @@ SQATrial, ) from ax.storage.sqa_store.sqa_config import SQAConfig -from ax.storage.utils import ( - AnalysisType, - DomainType, - MetricIntent, - ParameterConstraintType, -) +from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none -from pandas import read_json from pyre_extensions import assert_is_instance from sqlalchemy.orm.exc import DetachedInstanceError @@ -993,29 +982,6 @@ def data_from_sqa( dat.db_id = data_sqa.id return dat - def analysis_from_sqa( - self, - analysis_sqa: SQAAnalysis, - experiment: Experiment, - ) -> BaseAnalysis: - """Convert SQLAlchemy Analysis to Ax Analysis Object.""" - # TODO: generalize solution for pd dataframe type casting of "arm_name" column. - if analysis_sqa.experiment_analysis_type == AnalysisType.PLOTLY_VISUALIZATION: - return BasePlotlyVisualization( - experiment=experiment, - df_input=read_json( - analysis_sqa.dataframe_json, dtype={"arm_name": "str"} - ), - fig_input=pio.from_json(analysis_sqa.fig_json, output_type="Figure"), - ) - else: - return BaseAnalysis( - experiment=experiment, - df_input=read_json( - analysis_sqa.dataframe_json, dtype={"arm_name": "str"} - ), - ) - def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric: """Convert SQLAlchemy Metric to Ax Metric""" if metric_sqa.metric_type not in self.config.reverse_metric_registry: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 1f2857e5b90..f40fa0aebf4 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -11,12 +11,6 @@ from logging import Logger from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union -import plotly -import plotly.io as pio - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - from ax.core.arm import Arm from ax.core.base_trial import BaseTrial from ax.core.batch_trial import AbandonedArm, BatchTrial @@ -51,7 +45,6 @@ from ax.storage.json_store.encoder import object_to_json from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, - SQAAnalysis, SQAArm, SQAData, SQAExperiment, @@ -64,12 +57,7 @@ SQATrial, ) from ax.storage.sqa_store.sqa_config import SQAConfig -from ax.storage.utils import ( - AnalysisType, - DomainType, - MetricIntent, - ParameterConstraintType, -) +from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType from ax.utils.common.base import Base from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger @@ -1063,37 +1051,3 @@ def data_to_sqa( ) ), ) - - def analysis_to_sqa( - self, - analysis: BaseAnalysis, - ) -> SQAAnalysis: - """Convert Ax analysis to SQLAlchemy.""" - # pyre-fixme: Expected `Base` for 1st...ot `typing.Type[BaseAnalysis]`. - analysis_class: SQAAnalysis = self.config.class_to_sqa_class[BaseAnalysis] - - is_plotly_visualization: bool = isinstance(analysis, BasePlotlyVisualization) - - # pyre-fixme[29]: `SQAAnalysis` is not a function. - return analysis_class( - id=-1, - analysis_class_name=type(analysis).__name__, - time_analysis_start=-1, - time_analysis_completed=-1, - experiment_analysis_type=( - AnalysisType.PLOTLY_VISUALIZATION - if is_plotly_visualization - else AnalysisType.ANALYSIS - ), - dataframe_json=analysis.df.to_json(), - fig_json=( - None - if not is_plotly_visualization - else pio.to_json( - checked_cast(BasePlotlyVisualization, analysis).fig, - validate=True, - remove_uids=False, - ) - ), - plotly_version=plotly.__version__, - ) diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 430a6db3535..7bc978665b7 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -9,6 +9,8 @@ from datetime import datetime from typing import Any, Dict, List, Optional +from ax.analysis.analysis import AnalysisCardLevel + from ax.core.base_trial import TrialStatus from ax.core.batch_trial import LifecycleStage from ax.core.parameter import ParameterType @@ -34,13 +36,7 @@ ) from ax.storage.sqa_store.sqa_enum import IntEnum, StringEnum from ax.storage.sqa_store.timestamp import IntTimestamp -from ax.storage.utils import ( - AnalysisType, - DataType, - DomainType, - MetricIntent, - ParameterConstraintType, -) +from ax.storage.utils import DataType, DomainType, MetricIntent, ParameterConstraintType from sqlalchemy import ( BigInteger, Boolean, @@ -462,6 +458,31 @@ class SQATrial(Base): ) +class SQAAnalysisCard(Base): + __tablename__: str = "analysis_card" + + # pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`. + id: int = Column(Integer, primary_key=True) + # pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`. + name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + # pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`. + title: str = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) + # pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`. + subtitle: str = Column(Text, nullable=False) + # pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`. + level: AnalysisCardLevel = Column(IntEnum(AnalysisCardLevel), nullable=False) + # pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`. + dataframe_json: str = Column(Text(LONGTEXT_BYTES), nullable=False) + # pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`. + blob: str = Column(Text(LONGTEXT_BYTES), nullable=False) + # pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`. + blob_annotation: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + # pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`. + time_created: datetime = Column(IntTimestamp, nullable=False) + # pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`. + experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"), nullable=False) + + class SQAExperiment(Base): __tablename__: str = "experiment_v2" @@ -520,56 +541,6 @@ class SQAExperiment(Base): uselist=False, lazy=True, ) - - -class SQAAnalysis(Base): - __tablename__: str = "analysis" - - # pyre-fixme[8]: Attribute has type `int`; used as `Column[int]`. - id: int = Column(Integer, primary_key=True) - - # pyre-fixme[8]: Attribute has type `str`; used as `Column[str]`. - analysis_class_name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - # pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`. - time_analysis_start: datetime = Column(IntTimestamp, nullable=False) - # pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`. - time_analysis_completed: datetime = Column(IntTimestamp, nullable=False) - - # pyre-fixme[8]: Attribute has type `AnalysisType`; used as - # `Column[typing.Any]`. - experiment_analysis_type: AnalysisType = Column( - StringEnum(AnalysisType), nullable=False - ) - - # pyre-fixme[8]: Attribute has type `str`; used as `Column[str]`. - dataframe_json: str = Column(Text(LONGTEXT_BYTES), nullable=False) - - # pyre-fixme[8]: Attribute has type `Optional[str]`; used as - # `Column[Optional[str]]`. - fig_json: Optional[str] = Column(Text(LONGTEXT_BYTES), nullable=True) - # pyre-fixme[8]: Attribute has type `Optional[str]`; used as - # `Column[Optional[str]]`. - plotly_version: Optional[str] = Column( - String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True - ) - - # pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`. - experiment_id: int = Column(Integer) - # pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`. - analysis_report_id: int = Column(Integer, ForeignKey("analysis_report_v2.id")) - - -class SQAAnalysisReport(Base): - __tablename__: str = "analysis_report_v2" - - # pyre-fixme[8]: Attribute has type `int`; used as `Column[int]`. - id: int = Column(Integer, primary_key=True) - # pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`. - time_report_start: datetime = Column(IntTimestamp, nullable=False) - - analyses: Optional[List[SQAAnalysis]] = relationship( - "SQAAnalysis", cascade="all, delete-orphan", lazy="selectin" + analysis_cards: List[SQAAnalysisCard] = relationship( + "SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin" ) - - # pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`. - experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id")) diff --git a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index b3a5905f430..460cf486c7f 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -10,8 +10,7 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Type, Union -from ax.analysis.old.analysis_report import AnalysisReport -from ax.analysis.old.base_analysis import BaseAnalysis +from ax.analysis.analysis import AnalysisCard from ax.core.arm import Arm from ax.core.batch_trial import AbandonedArm @@ -36,8 +35,7 @@ from ax.storage.sqa_store.db import SQABase from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, - SQAAnalysis, - SQAAnalysisReport, + SQAAnalysisCard, SQAArm, SQAData, SQAExperiment, @@ -80,8 +78,7 @@ def _default_class_to_sqa_class(self=None) -> Dict[Type[Base], Type[SQABase]]: Metric: SQAMetric, Runner: SQARunner, Trial: SQATrial, - BaseAnalysis: SQAAnalysis, - AnalysisReport: SQAAnalysisReport, + AnalysisCard: SQAAnalysisCard, } class_to_sqa_class: Dict[Type[Base], Type[SQABase]] = field( diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 0c40611ca0e..038045805f4 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -64,10 +64,7 @@ update_runner_on_experiment, ) from ax.storage.sqa_store.sqa_classes import ( - AnalysisType, SQAAbandonedArm, - SQAAnalysis, - SQAAnalysisReport, SQAArm, SQAExperiment, SQAGeneratorRun, @@ -1929,24 +1926,3 @@ def test_CreateAllTablesException(self) -> None: engine.dialect.default_schema_name = "ax" with self.assertRaises(ValueError): create_all_tables(engine) - - def test_CreateAnalysisRecords(self) -> None: - - sqa_analysis = SQAAnalysis( - analysis_class_name="CrossValidationPlot", - experiment_analysis_type=AnalysisType.PLOTLY_VISUALIZATION, - time_analysis_start=datetime.now(), - time_analysis_completed=datetime.now(), - dataframe_json="none", - ) - with session_scope() as session: - _ = session.merge(sqa_analysis) - session.flush() - - def test_CreateAnalysisReport(self) -> None: - sqa_analysis_report = SQAAnalysisReport( - time_report_start=datetime.now(), - ) - with session_scope() as session: - _ = session.merge(sqa_analysis_report) - session.flush() diff --git a/ax/storage/sqa_store/tests/test_sqa_store_analysis.py b/ax/storage/sqa_store/tests/test_sqa_store_analysis.py deleted file mode 100644 index ae44f973ce8..00000000000 --- a/ax/storage/sqa_store/tests/test_sqa_store_analysis.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from logging import Logger - -import pandas as pd -import plotly.graph_objects as go - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - -from ax.analysis.old.cross_validation_plot import CrossValidationPlot - -from ax.modelbridge.registry import Models - -from ax.storage.sqa_store.db import init_test_engine_and_session_factory -from ax.storage.sqa_store.decoder import Decoder -from ax.storage.sqa_store.encoder import Encoder -from ax.storage.sqa_store.load import ( - _get_generation_strategy_sqa_immutable_opt_config_and_search_space, -) - -from ax.storage.sqa_store.sqa_config import SQAConfig - -from ax.storage.utils import AnalysisType - -from ax.utils.common.logger import get_logger - -from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast -from ax.utils.testing.core_stubs import ( - get_branin_experiment, - get_experiment_with_batch_trial, - get_range_parameter, - get_range_parameter2, -) - -from pandas.testing import assert_frame_equal - -logger: Logger = get_logger(__name__) - -GET_GS_SQA_IMM_FUNC = _get_generation_strategy_sqa_immutable_opt_config_and_search_space - - -class SQAStoreTest(TestCase): - def setUp(self) -> None: - super().setUp() - init_test_engine_and_session_factory(force_init=True) - self.config = SQAConfig() - self.encoder = Encoder(config=self.config) - self.decoder = Decoder(config=self.config) - self.experiment = get_experiment_with_batch_trial() - self.dummy_parameters = [ - get_range_parameter(), # w - get_range_parameter2(), # x - ] - - self.exp = get_branin_experiment(with_batch=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - def test_EncodeCrossValidationPlot(self) -> None: - plot = CrossValidationPlot(experiment=self.exp, model=self.model) - - sqa_analysis = self.encoder.analysis_to_sqa(analysis=plot) - - self.assertIn("CrossValidationPlot", sqa_analysis.analysis_class_name) - self.assertEqual( - AnalysisType.PLOTLY_VISUALIZATION, sqa_analysis.experiment_analysis_type - ) - - def test_EncodeBaseAnalysis(self) -> None: - analysis = BaseAnalysis(experiment=self.exp, df_input=pd.DataFrame()) - - sqa_analysis = self.encoder.analysis_to_sqa(analysis=analysis) - - self.assertIn("BaseAnalysis", sqa_analysis.analysis_class_name) - self.assertEqual(AnalysisType.ANALYSIS, sqa_analysis.experiment_analysis_type) - - def test_DecodeBaseAnalysis(self) -> None: - df = pd.DataFrame() - analysis = BaseAnalysis(experiment=self.exp, df_input=df) - - sqa_analysis = self.encoder.analysis_to_sqa(analysis=analysis) - - self.assertIn("BaseAnalysis", sqa_analysis.analysis_class_name) - self.assertEqual(AnalysisType.ANALYSIS, sqa_analysis.experiment_analysis_type) - - decoded_analysis = self.decoder.analysis_from_sqa( - experiment=self.exp, analysis_sqa=sqa_analysis - ) - self.assertFalse(isinstance(decoded_analysis, BasePlotlyVisualization)) - - # throws if not equal - assert_frame_equal(df, decoded_analysis.df, check_dtype=False) - - def test_EncodeBasePlotlyVisualization(self) -> None: - analysis = BasePlotlyVisualization( - experiment=self.exp, df_input=pd.DataFrame(), fig_input=go.Figure() - ) - - sqa_analysis = self.encoder.analysis_to_sqa(analysis=analysis) - - self.assertIn("BasePlotlyVisualization", sqa_analysis.analysis_class_name) - self.assertEqual( - AnalysisType.PLOTLY_VISUALIZATION, sqa_analysis.experiment_analysis_type - ) - - def test_DecodeCrossValidationPlot(self) -> None: - plot = CrossValidationPlot(experiment=self.exp, model=self.model) - - df = plot.get_df() - fig = plot.get_fig() - - sqa_analysis = self.encoder.analysis_to_sqa(analysis=plot) - - self.assertIn("CrossValidationPlot", sqa_analysis.analysis_class_name) - self.assertEqual( - AnalysisType.PLOTLY_VISUALIZATION, sqa_analysis.experiment_analysis_type - ) - - decoded_plot = self.decoder.analysis_from_sqa( - experiment=self.exp, analysis_sqa=sqa_analysis - ) - self.assertTrue(isinstance(decoded_plot, BasePlotlyVisualization)) - decoded_fig = checked_cast(BasePlotlyVisualization, decoded_plot) - # throws if not equal - assert_frame_equal(df, decoded_fig.df, check_dtype=False) - self.assertEqual(fig, decoded_fig.fig) - # add the equal check of the plot diff --git a/ax/storage/utils.py b/ax/storage/utils.py index 0e533004d85..9d0b33d8f1c 100644 --- a/ax/storage/utils.py +++ b/ax/storage/utils.py @@ -60,10 +60,3 @@ def stable_hash(s: str) -> int: int: Hash, converted to an integer. """ return int(md5(s.encode("utf-8")).hexdigest(), 16) - - -class AnalysisType(enum.Enum): - """Class for enumerating different experiment analysis types.""" - - ANALYSIS: str = "analysis" - PLOTLY_VISUALIZATION: str = "plotly_visualization" From 98d7d9f35d9f5287f3dde55c258aad1e7effa624 Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Thu, 8 Aug 2024 09:07:00 -0700 Subject: [PATCH 2/4] AnalysisCard encoder/decoder refactor Differential Revision: D60316107 --- ax/analysis/analysis.py | 3 ++- ax/storage/sqa_store/decoder.py | 17 +++++++++++++++++ ax/storage/sqa_store/encoder.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index f2e01bbc71f..1693acbdfc4 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -11,6 +11,7 @@ import pandas as pd from ax.core.experiment import Experiment from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.utils.common.base import Base class AnalysisCardLevel(Enum): @@ -21,7 +22,7 @@ class AnalysisCardLevel(Enum): CRITICAL = 4 -class AnalysisCard: +class AnalysisCard(Base): # Name of the analysis computed, usually the class name of the Analysis which # produced the card. Useful for grouping by when querying a large collection of # cards. diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 1d7e0e9e238..e3e754c97ce 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -14,6 +14,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union import pandas as pd +from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -50,6 +51,7 @@ from ax.storage.sqa_store.db import session_scope from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, + SQAAnalysisCard, SQAArm, SQAData, SQAExperiment, @@ -66,6 +68,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none +from pandas import read_json from pyre_extensions import assert_is_instance from sqlalchemy.orm.exc import DetachedInstanceError @@ -982,6 +985,20 @@ def data_from_sqa( dat.db_id = data_sqa.id return dat + def analysis_card_from_sqa( + self, + analysis_card_sqa: SQAAnalysisCard, + ) -> AnalysisCard: + """Convert SQLAlchemy Analysis to Ax Analysis Object.""" + return AnalysisCard( + name=analysis_card_sqa.name, + title=analysis_card_sqa.title, + subtitle=analysis_card_sqa.subtitle, + level=AnalysisCardLevel(analysis_card_sqa.level), + df=read_json(analysis_card_sqa.dataframe_json), + blob=analysis_card_sqa.blob, + ) + def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric: """Convert SQLAlchemy Metric to Ax Metric""" if metric_sqa.metric_type not in self.config.reverse_metric_registry: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index f40fa0aebf4..ba85586fef9 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -6,11 +6,14 @@ # pyre-strict +from datetime import datetime from enum import Enum from logging import Logger from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union +from ax.analysis.analysis import AnalysisCard + from ax.core.arm import Arm from ax.core.base_trial import BaseTrial from ax.core.batch_trial import AbandonedArm, BatchTrial @@ -45,6 +48,7 @@ from ax.storage.json_store.encoder import object_to_json from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, + SQAAnalysisCard, SQAArm, SQAData, SQAExperiment, @@ -1051,3 +1055,29 @@ def data_to_sqa( ) ), ) + + def analysis_card_to_sqa( + self, + analysis_card: AnalysisCard, + experiment_id: int, + timestamp: datetime, + ) -> SQAAnalysisCard: + """Convert Ax analysis to SQLAlchemy.""" + # pyre-fixme: Expected `Base` for 1st...ot `typing.Type[BaseAnalysis]`. + analysis_card_class: SQAAnalysisCard = self.config.class_to_sqa_class[ + AnalysisCard + ] + + # pyre-fixme[29]: `SQAAnalysisCard` is not a function. + return analysis_card_class( + id=analysis_card.db_id, + name=analysis_card.name, + title=analysis_card.title, + subtitle=analysis_card.subtitle, + level=analysis_card.level, + dataframe_json=analysis_card.df.to_json(), + blob=analysis_card.blob, + blob_annotation=analysis_card.blob_annotation, + time_created=timestamp, + experiment_id=experiment_id, + ) From 712ba5b0a9a194c78f6f81a1794b2cec1913334a Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Thu, 8 Aug 2024 09:07:00 -0700 Subject: [PATCH 3/4] AnalysisCard load/save methods Differential Revision: D60321245 --- ax/storage/sqa_store/load.py | 34 ++++++++++- ax/storage/sqa_store/save.py | 63 ++++++++++++++++--- ax/storage/sqa_store/tests/test_sqa_store.py | 64 ++++++++++++++++++++ 3 files changed, 151 insertions(+), 10 deletions(-) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index b8382c521c6..912f3160242 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -9,6 +9,8 @@ from math import ceil from typing import Any, cast, Dict, List, Optional, Type +from ax.analysis.analysis import AnalysisCard + from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric @@ -22,6 +24,7 @@ get_query_options_to_defer_large_model_cols, ) from ax.storage.sqa_store.sqa_classes import ( + SQAAnalysisCard, SQAExperiment, SQAGenerationStrategy, SQAGeneratorRun, @@ -63,7 +66,7 @@ def load_experiment( of metrics, this option converts the loaded metrics into a base metric avoiding conversion related to custom properties of the metric. """ - config = config or SQAConfig() + config = SQAConfig() if config is None else config decoder = Decoder(config=config) return _load_experiment( experiment_name=experiment_name, @@ -363,7 +366,7 @@ def load_generation_strategy_by_experiment_name( """Finds a generation strategy attached to an experiment specified by a name and restores it from its corresponding SQA object. """ - config = config or SQAConfig() + config = SQAConfig() if config is None else config decoder = Decoder(config=config) return _load_generation_strategy_by_experiment_name( experiment_name=experiment_name, @@ -381,7 +384,7 @@ def load_generation_strategy_by_id( reduced_state: bool = False, ) -> GenerationStrategy: """Finds a generation strategy stored by a given ID and restores it.""" - config = config or SQAConfig() + config = SQAConfig() if config is None else config decoder = Decoder(config=config) return _load_generation_strategy_by_id( gs_id=gs_id, decoder=decoder, experiment=experiment, reduced_state=reduced_state @@ -584,3 +587,28 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space( lazyload("generator_runs.metrics"), ], ) + + +def load_analysis_cards_by_experiment_name( + experiment_name: str, + config: Optional[SQAConfig] = None, +) -> List[AnalysisCard]: + """Loads analysis cards for an experiment.""" + config = SQAConfig() if config is None else config + decoder = Decoder(config=config) + analysis_card_sqa_class: SQAAnalysisCard = cast( + SQAAnalysisCard, decoder.config.class_to_sqa_class[AnalysisCard] + ) + exp_sqa_class: SQAExperiment = cast( + SQAExperiment, decoder.config.class_to_sqa_class[Experiment] + ) + with session_scope() as session: + analysis_cards_sqa = ( + session.query(analysis_card_sqa_class) + .join(exp_sqa_class.analysis_cards) + .filter(exp_sqa_class.name == experiment_name) + ) + return [ + decoder.analysis_card_from_sqa(analysis_card_sqa=analysis_card_sqa) + for analysis_card_sqa in analysis_cards_sqa + ] diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index 19c662699e3..4997a940f71 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -7,10 +7,13 @@ # pyre-strict import os +from datetime import datetime from logging import Logger from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union +from ax.analysis.analysis import AnalysisCard + from ax.core.base_trial import BaseTrial from ax.core.data import Data from ax.core.experiment import Experiment @@ -47,7 +50,7 @@ def save_experiment(experiment: Experiment, config: Optional[SQAConfig] = None) raise ValueError("Can only save instances of Experiment") if not experiment.has_name: raise ValueError("Experiment name must be set prior to saving.") - config = config or SQAConfig() + config = SQAConfig() if config is None else config encoder = Encoder(config=config) decoder = Decoder(config=config) _save_experiment(experiment=experiment, encoder=encoder, decoder=decoder) @@ -107,7 +110,7 @@ def save_generation_strategy( The ID of the saved generation strategy. """ # Start up SQA encoder. - config = config or SQAConfig() + config = SQAConfig() if config is None else config encoder = Encoder(config=config) decoder = Decoder(config=config) @@ -150,7 +153,7 @@ def save_or_update_trial( ) -> None: """Add new trial to the experiment, or update if already exists (using default SQAConfig).""" - config = config or SQAConfig() + config = SQAConfig() if config is None else config encoder = Encoder(config=config) decoder = Decoder(config=config) _save_or_update_trial( @@ -189,7 +192,7 @@ def save_or_update_trials( will also be added to the experiment, but existing data objects in the database will *not* be updated or removed. """ - config = config or SQAConfig() + config = SQAConfig() if config is None else config encoder = Encoder(config=config) decoder = Decoder(config=config) _save_or_update_trials( @@ -308,7 +311,7 @@ def update_generation_strategy( ) -> None: """Update generation strategy's current step and attach generator runs (using default SQAConfig).""" - config = config or SQAConfig() + config = SQAConfig() if config is None else config encoder = Encoder(config=config) decoder = Decoder(config=config) _update_generation_strategy( @@ -450,7 +453,7 @@ def update_properties_on_experiment( experiment_with_updated_properties: Experiment, config: Optional[SQAConfig] = None, ) -> None: - config = config or SQAConfig() + config = SQAConfig() if config is None else config exp_sqa_class = config.class_to_sqa_class[Experiment] exp_id = experiment_with_updated_properties.db_id @@ -469,7 +472,7 @@ def update_properties_on_trial( trial_with_updated_properties: BaseTrial, config: Optional[SQAConfig] = None, ) -> None: - config = config or SQAConfig() + config = SQAConfig() if config is None else config trial_sqa_class = config.class_to_sqa_class[Trial] trial_id = trial_with_updated_properties.db_id @@ -484,6 +487,52 @@ def update_properties_on_trial( ) +def save_analysis_cards( + analysis_cards: List[AnalysisCard], + experiment: Experiment, + config: Optional[SQAConfig] = None, +) -> None: + # Start up SQA encoder. + config = SQAConfig() if config is None else config + encoder = Encoder(config=config) + decoder = Decoder(config=config) + timestamp = datetime.utcnow() + _save_analysis_cards( + analysis_cards=analysis_cards, + experiment=experiment, + timestamp=timestamp, + encoder=encoder, + decoder=decoder, + ) + + +def _save_analysis_cards( + analysis_cards: List[AnalysisCard], + experiment: Experiment, + timestamp: datetime, + encoder: Encoder, + decoder: Decoder, +) -> None: + if any(analysis_card.db_id is not None for analysis_card in analysis_cards): + raise ValueError("Analysis cards cannot be updated.") + if experiment.db_id is None: + raise ValueError( + f"Experiment {experiment.name} should be saved before analysis cards." + ) + _bulk_merge_into_session( + objs=analysis_cards, + encode_func=encoder.analysis_card_to_sqa, + decode_func=decoder.analysis_card_from_sqa, + encode_args_list=[ + { + "experiment_id": experiment.db_id, + "timestamp": timestamp, + } + for _analysis_card in analysis_cards + ], + ) + + def _merge_into_session( obj: Base, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 038045805f4..daccf4df0d9 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -14,6 +14,11 @@ from unittest import mock from unittest.mock import MagicMock, Mock, patch +import pandas as pd + +from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel +from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.core.arm import Arm from ax.core.batch_trial import BatchTrial, LifecycleStage from ax.core.generator_run import GeneratorRun @@ -48,12 +53,14 @@ _get_experiment_immutable_opt_config_and_search_space, _get_experiment_sqa_immutable_opt_config_and_search_space, _get_generation_strategy_sqa_immutable_opt_config_and_search_space, + load_analysis_cards_by_experiment_name, load_experiment, load_generation_strategy_by_experiment_name, load_generation_strategy_by_id, ) from ax.storage.sqa_store.reduced_state import GR_LARGE_MODEL_ATTRS from ax.storage.sqa_store.save import ( + save_analysis_cards, save_experiment, save_generation_strategy, save_or_update_trial, @@ -114,6 +121,7 @@ get_synthetic_runner, ) from ax.utils.testing.modeling_stubs import get_generation_strategy +from plotly import graph_objects as go, io as pio logger: Logger = get_logger(__name__) @@ -1926,3 +1934,59 @@ def test_CreateAllTablesException(self) -> None: engine.dialect.default_schema_name = "ax" with self.assertRaises(ValueError): create_all_tables(engine) + + def test_AnalysisCard(self) -> None: + test_df = pd.DataFrame( + columns=["a", "b"], + data=[ + [1, 2], + [3, 4], + ], + ) + base_analysis_card = AnalysisCard( + name="test_base_analysis_card", + title="test_title", + subtitle="test_subtitle", + level=AnalysisCardLevel.DEBUG, + df=test_df, + blob="test blob", + ) + markdown_analysis_card = MarkdownAnalysisCard( + name="test_markdown_analysis_card", + title="test_title", + subtitle="test_subtitle", + level=AnalysisCardLevel.DEBUG, + df=test_df, + blob="This is some **really cool** markdown", + ) + plotly_analysis_card = PlotlyAnalysisCard( + name="test_plotly_analysis_card", + title="test_title", + subtitle="test_subtitle", + level=AnalysisCardLevel.DEBUG, + df=test_df, + blob=pio.to_json(go.Figure()), + ) + with self.subTest("test_save_analysis_cards"): + save_experiment(self.experiment) + save_analysis_cards( + [base_analysis_card, markdown_analysis_card, plotly_analysis_card], + self.experiment, + ) + with self.subTest("test_load_analysis_cards"): + loaded_analysis_cards = load_analysis_cards_by_experiment_name( + self.experiment.name + ) + self.assertEqual(len(loaded_analysis_cards), 3) + self.assertEqual( + loaded_analysis_cards[0].blob, + base_analysis_card.blob, + ) + self.assertEqual( + loaded_analysis_cards[1].blob, + markdown_analysis_card.blob, + ) + self.assertEqual( + loaded_analysis_cards[2].blob, + plotly_analysis_card.blob, + ) From 3d0b80a75ae6fd10a937e91e9efb87300c4d2790 Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Thu, 8 Aug 2024 12:08:09 -0700 Subject: [PATCH 4/4] Healthcheck analysis class (#2646) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2646 Add abstract class for health checks Reviewed By: mpolson64 Differential Revision: D60417357 --- ax/analysis/healthcheck/__init__.py | 18 +++++++++ .../healthcheck/healthcheck_analysis.py | 38 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 ax/analysis/healthcheck/__init__.py create mode 100644 ax/analysis/healthcheck/healthcheck_analysis.py diff --git a/ax/analysis/healthcheck/__init__.py b/ax/analysis/healthcheck/__init__.py new file mode 100644 index 00000000000..2d99e3e9a0d --- /dev/null +++ b/ax/analysis/healthcheck/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.healthcheck.healthcheck_analysis import ( + HealthcheckAnalysis, + HealthcheckAnalysisCard, + HealthcheckStatus, +) + +__all__ = [ + "HealthcheckAnalysis", + "HealthcheckAnalysisCard", + "HealthcheckStatus", +] diff --git a/ax/analysis/healthcheck/healthcheck_analysis.py b/ax/analysis/healthcheck/healthcheck_analysis.py new file mode 100644 index 00000000000..c7fd74d025e --- /dev/null +++ b/ax/analysis/healthcheck/healthcheck_analysis.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import json +from enum import IntEnum +from typing import Optional + +from ax.analysis.analysis import AnalysisCard +from ax.core.experiment import Experiment +from ax.modelbridge.generation_strategy import GenerationStrategy + + +class HealthcheckStatus(IntEnum): + PASS = 0 + FAIL = 1 + WARNING = 2 + + +class HealthcheckAnalysisCard(AnalysisCard): + blob_annotation = "healthcheck" + + def get_status(self) -> HealthcheckStatus: + return HealthcheckStatus(json.loads(self.blob)["status"]) + + +class HealthcheckAnalysis: + """ + An analysis that performs a health check. + """ + + def compute( + self, + experiment: Optional[Experiment] = None, + generation_strategy: Optional[GenerationStrategy] = None, + ) -> HealthcheckAnalysisCard: ...