Skip to content

Commit

Permalink
Allow more flexible definition of which trial statuses to fit. (faceb…
Browse files Browse the repository at this point in the history
…ook#2432)

Summary:

Changes `observations_from_data` and `observations_from_map_data` signatures, replacing arg `fit_abandoned: bool` -> `statuses_to_fit` and `statuses_to_fit_map_metric` which are both `Optional[List[TrialStatus]]`. In `observations_from_data`, these default to `None` in which case `{TrialStatus.COMPLETED}` is used for any MapMetrics, else `NON_ABANDONED_STATUSES`, i.e., all trial statuses except abandoned. In `observations_from_map_data`, `NON_ABANDONED_STATUSES` are the default for all metrics.

NOTE: As of this diff, any rows of `Data.df` containing `metric_names` that don't exist on `experiment` will be filtered out ([pointer](https://www.internalfb.com/diff/D56634321?permalink=330145676765780)).

Reviewed By: saitcakmak

Differential Revision: D56634321
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed May 6, 2024
1 parent 3f2ad50 commit f2b1dda
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 40 deletions.
7 changes: 5 additions & 2 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from copy import deepcopy
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

from ax.core.arm import Arm
from ax.core.data import Data
Expand Down Expand Up @@ -76,7 +76,8 @@ class TrialStatus(int, Enum):
NOTE: Data for abandoned trials (or abandoned arms in batch trials) is
not passed to the model as part of training data, unless ``fit_abandoned``
option is specified to model bridge.
option is specified to model bridge. Additionally, data from MapMetrics is
typically excluded unless the corresponding trial is completed.
"""

CANDIDATE = 0
Expand Down Expand Up @@ -165,6 +166,8 @@ def __repr__(self) -> str:
TrialStatus.EARLY_STOPPED,
]

NON_ABANDONED_STATUSES: Set[TrialStatus] = set(TrialStatus) - {TrialStatus.ABANDONED}


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def immutable_once_run(func: Callable) -> Callable:
Expand Down
113 changes: 92 additions & 21 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import numpy as np
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import NON_ABANDONED_STATUSES, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
from ax.core.types import TCandidateMetadata, TParameterization
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
Expand Down Expand Up @@ -253,7 +255,8 @@ def _observations_from_dataframe(
cols: List[str],
arm_name_only: bool,
map_keys: Iterable[str],
include_abandoned: bool,
statuses_to_include: Set[TrialStatus],
statuses_to_include_map_metric: Set[TrialStatus],
map_keys_as_parameters: bool = False,
) -> List[Observation]:
"""Helper method for extracting observations grouped by `cols` from `df`.
Expand All @@ -262,10 +265,13 @@ def _observations_from_dataframe(
experiment: Experiment with arm parameters.
df: DataFrame derived from experiment Data.
cols: columns used to group data into different observations.
arm_name_only: whether arm_name is the only column in `cols`.
map_keys: columns that map dict-like Data
e.g. `timestamp` in timeseries data, `epoch` in ML training traces.
include_abandoned: Whether data for abandoned trials and arms should
be included in the observations, returned from this function.
statuses_to_include: data from non-MapMetrics will only be included for trials
with statuses in this set.
statuses_to_include_map_metric: data from MapMetrics will only be included for
trials with statuses in this set.
map_keys_as_parameters: Whether map_keys should be returned as part of
the parameters of the Observation objects.
Expand All @@ -287,8 +293,11 @@ def _observations_from_dataframe(
arm_name = features["arm_name"]
trial_index = features.get("trial_index", None)

is_arm_abandoned = False
trial_status = None
if trial_index is not None:
trial = experiment.trials[trial_index]
trial_status = trial.status
metadata = trial._get_candidate_metadata(arm_name) or {}
if Keys.TRIAL_COMPLETION_TIMESTAMP not in metadata:
if trial._time_completed is not None:
Expand All @@ -297,18 +306,15 @@ def _observations_from_dataframe(
).timestamp()
obs_kwargs[Keys.METADATA] = metadata

if not include_abandoned and trial.status.is_abandoned:
# Exclude abandoned trials.
continue

if not include_abandoned and isinstance(trial, BatchTrial):
# Exclude abandoned arms from batch trial's observations.
# Determine if this arm is abandoned.
is_arm_abandoned = trial.is_abandoned
if isinstance(trial, BatchTrial):
if trial.index not in abandoned_arms_dict:
# Same abandoned arm names to dict to avoid recomputing them
# on creation of every observation.
abandoned_arms_dict[trial.index] = trial.abandoned_arm_names
if arm_name in abandoned_arms_dict[trial.index]:
continue
is_arm_abandoned = True

obs_parameters = experiment.arms_by_name[arm_name].parameters.copy()
if obs_parameters:
Expand All @@ -332,6 +338,16 @@ def _observations_from_dataframe(
obs_parameters[map_key] = features[map_key]
else:
obs_kwargs[Keys.METADATA][map_key] = features[map_key]
d = _filter_data_on_status(
df=d,
experiment=experiment,
trial_status=trial_status,
is_arm_abandoned=is_arm_abandoned,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
)
if len(d) == 0:
continue
observations.append(
Observation(
features=ObservationFeatures(**obs_kwargs),
Expand All @@ -346,6 +362,42 @@ def _observations_from_dataframe(
return observations


def _filter_data_on_status(
df: pd.DataFrame,
experiment: experiment.Experiment,
trial_status: Optional[TrialStatus],
# Arms on a BatchTrial can be abandoned even if the BatchTrial is not.
# Data will be filtered out if is_arm_abandoned is True and the corresponding
# statuses_to_include does not contain TrialStatus.ABANDONED.
is_arm_abandoned: bool,
statuses_to_include: Set[TrialStatus],
statuses_to_include_map_metric: Set[TrialStatus],
) -> pd.DataFrame:
if "metric_name" not in df.columns:
raise ValueError(f"`metric_name` column is missing from {df!r}.")
dfs = []
for g, d in df.groupby(by="metric_name"):
metric_name = g
# Filter out any metrics that are not on the experiment.
if metric_name not in experiment.metrics:
continue
metric = experiment.metrics[metric_name]
statuses_to_include_metric = (
statuses_to_include_map_metric
if isinstance(metric, MapMetric)
else statuses_to_include
)
if trial_status is not None and trial_status not in statuses_to_include_metric:
continue
if is_arm_abandoned and TrialStatus.ABANDONED not in statuses_to_include_metric:
continue
dfs.append(d)
if len(dfs) == 0:
return pd.DataFrame()
df = pd.concat(dfs)
return df


def get_feature_cols(data: Data, is_map_data: bool = False) -> List[str]:
feature_cols = OBS_COLS.intersection(data.df.columns)
# note we use this check, rather than isinstance, since
Expand All @@ -369,7 +421,10 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> List[str]:


def observations_from_data(
experiment: experiment.Experiment, data: Data, include_abandoned: bool = False
experiment: experiment.Experiment,
data: Data,
statuses_to_include: Optional[Set[TrialStatus]] = None,
statuses_to_include_map_metric: Optional[Set[TrialStatus]] = None,
) -> List[Observation]:
"""Convert Data to observations.
Expand All @@ -382,13 +437,18 @@ def observations_from_data(
Args:
experiment: Experiment with arm parameters.
data: Data of observations.
include_abandoned: Whether data for abandoned trials and arms should
be included in the observations, returned from this function.
statuses_to_include: data from non-MapMetrics will only be included for trials
with statuses in this set. Defaults to all statuses except abandoned.
statuses_to_include_map_metric: data from MapMetrics will only be included for
trials with statuses in this set. Defaults to completed status only.
Returns:
List of Observation objects.
"""

if statuses_to_include is None:
statuses_to_include = NON_ABANDONED_STATUSES
if statuses_to_include_map_metric is None:
statuses_to_include_map_metric = {TrialStatus.COMPLETED}
feature_cols = get_feature_cols(data)
observations = []
arm_name_only = len(feature_cols) == 1 # there will always be an arm name
Expand Down Expand Up @@ -418,8 +478,9 @@ def observations_from_data(
df=complete_df,
cols=feature_cols,
arm_name_only=arm_name_only,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys=[],
include_abandoned=include_abandoned,
)
)
if incomplete_df is not None:
Expand All @@ -430,8 +491,9 @@ def observations_from_data(
df=incomplete_df,
cols=complete_feature_cols,
arm_name_only=arm_name_only,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys=[],
include_abandoned=include_abandoned,
)
)
return observations
Expand All @@ -440,7 +502,8 @@ def observations_from_data(
def observations_from_map_data(
experiment: experiment.Experiment,
map_data: MapData,
include_abandoned: bool = False,
statuses_to_include: Optional[Set[TrialStatus]] = None,
statuses_to_include_map_metric: Optional[Set[TrialStatus]] = None,
map_keys_as_parameters: bool = False,
limit_rows_per_metric: Optional[int] = None,
limit_rows_per_group: Optional[int] = None,
Expand All @@ -456,8 +519,10 @@ def observations_from_map_data(
Args:
experiment: Experiment with arm parameters.
map_data: MapData of observations.
include_abandoned: Whether data for abandoned trials and arms should
be included in the observations, returned from this function.
statuses_to_include: data from non-MapMetrics will only be included for trials
with statuses in this set. Defaults to all statuses except abandoned.
statuses_to_include_map_metric: data from MapMetrics will only be included for
trials with statuses in this set. Defaults to all statuses except abandoned.
map_keys_as_parameters: Whether map_keys should be returned as part of
the parameters of the Observation objects.
limit_rows_per_metric: If specified, uses MapData.subsample() with
Expand All @@ -473,6 +538,10 @@ def observations_from_map_data(
Returns:
List of Observation objects.
"""
if statuses_to_include is None:
statuses_to_include = NON_ABANDONED_STATUSES
if statuses_to_include_map_metric is None:
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
if limit_rows_per_metric is not None or limit_rows_per_group is not None:
map_data = map_data.subsample(
map_key=map_data.map_keys[0],
Expand Down Expand Up @@ -517,7 +586,8 @@ def observations_from_map_data(
cols=feature_cols,
arm_name_only=arm_name_only,
map_keys=map_data.map_keys,
include_abandoned=include_abandoned,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
)
)
Expand All @@ -530,7 +600,8 @@ def observations_from_map_data(
cols=complete_feature_cols,
arm_name_only=arm_name_only,
map_keys=map_data.map_keys,
include_abandoned=include_abandoned,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
)
)
Expand Down
Loading

0 comments on commit f2b1dda

Please sign in to comment.