Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Healthcheck analysis class #2646

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.utils.common.base import Base


class AnalysisCardLevel(Enum):
Expand All @@ -21,7 +22,7 @@ class AnalysisCardLevel(Enum):
CRITICAL = 4


class AnalysisCard:
class AnalysisCard(Base):
# Name of the analysis computed, usually the class name of the Analysis which
# produced the card. Useful for grouping by when querying a large collection of
# cards.
Expand Down
18 changes: 18 additions & 0 deletions ax/analysis/healthcheck/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
HealthcheckStatus,
)

__all__ = [
"HealthcheckAnalysis",
"HealthcheckAnalysisCard",
"HealthcheckStatus",
]
38 changes: 38 additions & 0 deletions ax/analysis/healthcheck/healthcheck_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import json
from enum import IntEnum
from typing import Optional

from ax.analysis.analysis import AnalysisCard
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy


class HealthcheckStatus(IntEnum):
PASS = 0
FAIL = 1
WARNING = 2


class HealthcheckAnalysisCard(AnalysisCard):
blob_annotation = "healthcheck"

def get_status(self) -> HealthcheckStatus:
return HealthcheckStatus(json.loads(self.blob)["status"])


class HealthcheckAnalysis:
"""
An analysis that performs a health check.
"""

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> HealthcheckAnalysisCard: ...
45 changes: 14 additions & 31 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
import plotly.io as pio

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -54,7 +51,7 @@
from ax.storage.sqa_store.db import session_scope
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -67,12 +64,7 @@
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.utils import (
AnalysisType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
Expand Down Expand Up @@ -993,28 +985,19 @@ def data_from_sqa(
dat.db_id = data_sqa.id
return dat

def analysis_from_sqa(
def analysis_card_from_sqa(
self,
analysis_sqa: SQAAnalysis,
experiment: Experiment,
) -> BaseAnalysis:
analysis_card_sqa: SQAAnalysisCard,
) -> AnalysisCard:
"""Convert SQLAlchemy Analysis to Ax Analysis Object."""
# TODO: generalize solution for pd dataframe type casting of "arm_name" column.
if analysis_sqa.experiment_analysis_type == AnalysisType.PLOTLY_VISUALIZATION:
return BasePlotlyVisualization(
experiment=experiment,
df_input=read_json(
analysis_sqa.dataframe_json, dtype={"arm_name": "str"}
),
fig_input=pio.from_json(analysis_sqa.fig_json, output_type="Figure"),
)
else:
return BaseAnalysis(
experiment=experiment,
df_input=read_json(
analysis_sqa.dataframe_json, dtype={"arm_name": "str"}
),
)
return AnalysisCard(
name=analysis_card_sqa.name,
title=analysis_card_sqa.title,
subtitle=analysis_card_sqa.subtitle,
level=AnalysisCardLevel(analysis_card_sqa.level),
df=read_json(analysis_card_sqa.dataframe_json),
blob=analysis_card_sqa.blob,
)

def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric:
"""Convert SQLAlchemy Metric to Ax Metric"""
Expand Down
66 changes: 25 additions & 41 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@

# pyre-strict

from datetime import datetime
from enum import Enum

from logging import Logger
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import plotly
import plotly.io as pio

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
Expand Down Expand Up @@ -51,7 +48,7 @@
from ax.storage.json_store.encoder import object_to_json
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -64,12 +61,7 @@
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.utils import (
AnalysisType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -1064,36 +1056,28 @@ def data_to_sqa(
),
)

def analysis_to_sqa(
def analysis_card_to_sqa(
self,
analysis: BaseAnalysis,
) -> SQAAnalysis:
analysis_card: AnalysisCard,
experiment_id: int,
timestamp: datetime,
) -> SQAAnalysisCard:
"""Convert Ax analysis to SQLAlchemy."""
# pyre-fixme: Expected `Base` for 1st...ot `typing.Type[BaseAnalysis]`.
analysis_class: SQAAnalysis = self.config.class_to_sqa_class[BaseAnalysis]

is_plotly_visualization: bool = isinstance(analysis, BasePlotlyVisualization)

# pyre-fixme[29]: `SQAAnalysis` is not a function.
return analysis_class(
id=-1,
analysis_class_name=type(analysis).__name__,
time_analysis_start=-1,
time_analysis_completed=-1,
experiment_analysis_type=(
AnalysisType.PLOTLY_VISUALIZATION
if is_plotly_visualization
else AnalysisType.ANALYSIS
),
dataframe_json=analysis.df.to_json(),
fig_json=(
None
if not is_plotly_visualization
else pio.to_json(
checked_cast(BasePlotlyVisualization, analysis).fig,
validate=True,
remove_uids=False,
)
),
plotly_version=plotly.__version__,
analysis_card_class: SQAAnalysisCard = self.config.class_to_sqa_class[
AnalysisCard
]

# pyre-fixme[29]: `SQAAnalysisCard` is not a function.
return analysis_card_class(
id=analysis_card.db_id,
name=analysis_card.name,
title=analysis_card.title,
subtitle=analysis_card.subtitle,
level=analysis_card.level,
dataframe_json=analysis_card.df.to_json(),
blob=analysis_card.blob,
blob_annotation=analysis_card.blob_annotation,
time_created=timestamp,
experiment_id=experiment_id,
)
34 changes: 31 additions & 3 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from math import ceil
from typing import Any, cast, Dict, List, Optional, Type

from ax.analysis.analysis import AnalysisCard

from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
Expand All @@ -22,6 +24,7 @@
get_query_options_to_defer_large_model_cols,
)
from ax.storage.sqa_store.sqa_classes import (
SQAAnalysisCard,
SQAExperiment,
SQAGenerationStrategy,
SQAGeneratorRun,
Expand Down Expand Up @@ -63,7 +66,7 @@ def load_experiment(
of metrics, this option converts the loaded metrics into a base
metric avoiding conversion related to custom properties of the metric.
"""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
return _load_experiment(
experiment_name=experiment_name,
Expand Down Expand Up @@ -363,7 +366,7 @@ def load_generation_strategy_by_experiment_name(
"""Finds a generation strategy attached to an experiment specified by a name
and restores it from its corresponding SQA object.
"""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
return _load_generation_strategy_by_experiment_name(
experiment_name=experiment_name,
Expand All @@ -381,7 +384,7 @@ def load_generation_strategy_by_id(
reduced_state: bool = False,
) -> GenerationStrategy:
"""Finds a generation strategy stored by a given ID and restores it."""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
return _load_generation_strategy_by_id(
gs_id=gs_id, decoder=decoder, experiment=experiment, reduced_state=reduced_state
Expand Down Expand Up @@ -584,3 +587,28 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space(
lazyload("generator_runs.metrics"),
],
)


def load_analysis_cards_by_experiment_name(
experiment_name: str,
config: Optional[SQAConfig] = None,
) -> List[AnalysisCard]:
"""Loads analysis cards for an experiment."""
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
analysis_card_sqa_class: SQAAnalysisCard = cast(
SQAAnalysisCard, decoder.config.class_to_sqa_class[AnalysisCard]
)
exp_sqa_class: SQAExperiment = cast(
SQAExperiment, decoder.config.class_to_sqa_class[Experiment]
)
with session_scope() as session:
analysis_cards_sqa = (
session.query(analysis_card_sqa_class)
.join(exp_sqa_class.analysis_cards)
.filter(exp_sqa_class.name == experiment_name)
)
return [
decoder.analysis_card_from_sqa(analysis_card_sqa=analysis_card_sqa)
for analysis_card_sqa in analysis_cards_sqa
]
Loading
Loading