Skip to content

Commit

Permalink
Transform batch to new sq (#2755)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2755

Implement transform for scaling batches to the same SQ.

Reviewed By: ItsMrLin

Differential Revision: D62266853
  • Loading branch information
sdaulton authored and facebook-github-bot committed Sep 9, 2024
1 parent 4b54cc5 commit fbcaba7
Show file tree
Hide file tree
Showing 5 changed files with 503 additions and 101 deletions.
134 changes: 79 additions & 55 deletions ax/modelbridge/transforms/relativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from abc import ABC, abstractmethod

from math import sqrt
from typing import Callable, Optional, TYPE_CHECKING
from typing import Callable, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
Expand Down Expand Up @@ -47,16 +47,15 @@ class BaseRelativize(Transform, ABC):
appropriate transform/untransform differently.
"""

MISSING_STATUS_QUO_ERROR = "Cannot relativize data without status quo data"

def __init__(
self,
search_space: Optional[SearchSpace] = None,
observations: Optional[list[Observation]] = None,
modelbridge: Optional[modelbridge_module.base.ModelBridge] = None,
config: Optional[TConfig] = None,
) -> None:
assert observations is not None, "Relativize requires observations"
cls_name = self.__class__.__name__
assert observations is not None, f"{cls_name} requires observations"
super().__init__(
search_space=search_space,
observations=observations,
Expand All @@ -65,9 +64,18 @@ def __init__(
)
# self.modelbridge should NOT be modified
self.modelbridge: ModelBridge = not_none(
modelbridge, "Relativize transform requires a modelbridge"
modelbridge, f"{cls_name} transform requires a modelbridge"
)

self.status_quo_data_by_trial: dict[int, ObservationData] = not_none(
self.modelbridge.status_quo_data_by_trial,
f"{cls_name} requires status quo data.",
)
# use latest index of latest observed trial by default
# to handle pending trials, which may not have a trial_index
# if TrialAsTask was not used to generate the trial.
self.default_trial_idx: int = max(self.status_quo_data_by_trial.keys())

@property
@abstractmethod
def control_as_constant(self) -> bool:
Expand Down Expand Up @@ -158,42 +166,36 @@ def untransform_observations(
observations=observations, rel_op=unrelativize
)

def _rel_op_on_observations(
def _get_relative_data_from_obs(
self,
observations: list[Observation],
obs: Observation,
rel_op: Callable[..., tuple[np.ndarray, np.ndarray]],
) -> list[Observation]:

sq_data_by_trial: dict[int, ObservationData] = not_none(
self.modelbridge.status_quo_data_by_trial, self.MISSING_STATUS_QUO_ERROR
) -> ObservationData:
idx = (
int(obs.features.trial_index)
if obs.features.trial_index is not None
else self.default_trial_idx
)

# use latest index of latest observed trial by default
# to handle pending trials, which may not have a trial_index
# if TrialAsTask was not used to generate the trial.
default_trial_idx: int = max(sq_data_by_trial.keys())

def _get_relative_data_from_obs(
obs: Observation,
rel_op: Callable[..., tuple[np.ndarray, np.ndarray]],
) -> ObservationData:
idx = (
int(obs.features.trial_index)
if obs.features.trial_index is not None
else default_trial_idx
)
if idx not in sq_data_by_trial:
raise ValueError(self.MISSING_STATUS_QUO_ERROR)
return self._get_relative_data(
data=obs.data,
status_quo_data=sq_data_by_trial[idx],
rel_op=rel_op,
if idx not in self.status_quo_data_by_trial:
raise ValueError(
f"{self.__class__.__name__} requires status quo data for trial "
f"index {idx}."
)
return self._get_relative_data(
data=obs.data,
status_quo_data=self.status_quo_data_by_trial[idx],
rel_op=rel_op,
)

def _rel_op_on_observations(
self,
observations: list[Observation],
rel_op: Callable[..., tuple[np.ndarray, np.ndarray]],
) -> list[Observation]:
return [
Observation(
features=obs.features,
data=_get_relative_data_from_obs(obs, rel_op),
data=self._get_relative_data_from_obs(obs, rel_op),
arm_name=obs.arm_name,
)
for obs in observations
Expand Down Expand Up @@ -225,37 +227,59 @@ def _get_relative_data(
covariance=np.zeros((L, L)),
)
for i, metric in enumerate(data.metric_names):
try:
j = next(
k for k in range(L) if status_quo_data.metric_names[k] == metric
)
except (IndexError, StopIteration):
raise ValueError(
"Relativization cannot be performed because "
"ObservationData for status quo is missing metrics"
)

j = get_metric_index(data=status_quo_data, metric_name=metric)
means_t = data.means[i]
sems_t = sqrt(data.covariance[i][i])
mean_c = status_quo_data.means[j]
sem_c = sqrt(status_quo_data.covariance[j][j])

# if the is the status quo
if means_t == mean_c and sems_t == sem_c:
means_rel, sems_rel = 0, 0
else:
means_rel, sems_rel = rel_op(
means_t=means_t,
sems_t=sems_t,
mean_c=mean_c,
sem_c=sem_c,
as_percent=True,
control_as_constant=self.control_as_constant,
)
means_rel, sems_rel = self._get_rel_mean_sem(
means_t=means_t,
sems_t=sems_t,
mean_c=mean_c,
sem_c=sem_c,
metric=metric,
rel_op=rel_op,
)
result.means[i] = means_rel
result.covariance[i][i] = sems_rel**2
return result

def _get_rel_mean_sem(
self,
means_t: float,
sems_t: float,
mean_c: float,
sem_c: float,
metric: str,
rel_op: Callable[..., tuple[np.ndarray, np.ndarray]],
) -> Tuple[Union[float, np.ndarray], Union[float, np.ndarray]]:
"""Compute (un)relativized mean and sem for a single metric."""
# if the is the status quo
if means_t == mean_c and sems_t == sem_c:
return 0, 0
return rel_op(
means_t=means_t,
sems_t=sems_t,
mean_c=mean_c,
sem_c=sem_c,
as_percent=True,
control_as_constant=self.control_as_constant,
)


def get_metric_index(data: ObservationData, metric_name: str) -> int:
"""Get the index of a metric in the ObservationData."""
try:
return next(
k for k, name in enumerate(data.metric_names) if name == metric_name
)
except (IndexError, StopIteration):
raise ValueError(
"Relativization cannot be performed because "
"ObservationData for status quo is missing metrics"
)


class Relativize(BaseRelativize):
"""
Expand Down
Loading

0 comments on commit fbcaba7

Please sign in to comment.