Skip to content

Commit

Permalink
Remove status quo weight override from COPY_DB_IDS_ATTRS_TO_SKIP (fac…
Browse files Browse the repository at this point in the history
…ebook#2615)

Summary:
Pull Request resolved: facebook#2615

Encoder and decoder both deal with `_status_quo_weight_override`.  In decoder in happens through the status quo generator run.  The problem is when a trial has a `_status_quo_weight_override`, but no status quo.  The solution in this diff is to make it impossible (unless you use protected fields directly) to have a `_status_quo_weight_override` without a `status_quo`.

## How could this be wrong?
If the user needs to store a status quo weight override on the trial for later but does not yet have a status quo.  But I don't know why they could only calculate the weight now and not later.

Reviewed By: mgarrard

Differential Revision: D60413211
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jul 31, 2024
1 parent 8e07000 commit 9dff7f2
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 26 deletions.
8 changes: 6 additions & 2 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def status_quo(self, status_quo: Optional[Arm]) -> None:
def unset_status_quo(self) -> None:
"""Set the status quo to None."""
self._status_quo = None
self._status_quo_weight_override = None
self._refresh_arms_by_name()

@immutable_once_run
Expand All @@ -362,8 +363,11 @@ def set_status_quo_with_weight(
result in the weight being additive over all generator runs.
"""
# Assign a name to this arm if none exists
if weight is not None and weight <= 0.0:
raise ValueError("Status quo weight must be positive.")
if weight is not None:
if weight <= 0.0:
raise ValueError("Status quo weight must be positive.")
if status_quo is None:
raise ValueError("Cannot set weight because status quo is not defined.")

if status_quo is not None:
self.experiment.search_space.check_types(
Expand Down
57 changes: 37 additions & 20 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def setUp(self) -> None:
weights = get_weights()
self.status_quo = arms[0]
self.sq_weight = weights[0]
self.new_sq = Arm(parameters={"w": 0.95, "x": 1, "y": "foo", "z": True})
self.arms = arms[1:]
self.weights = weights[1:]
self.batch.add_arms_and_weights(arms=self.arms, weights=self.weights)
Expand Down Expand Up @@ -145,7 +146,6 @@ def test_InitWithGeneratorRun(self) -> None:
self.assertEqual(len(self.batch.generator_run_structs), 1)

def test_StatusQuoOverlap(self) -> None:
new_sq = Arm(parameters={"w": 0.95, "x": 1, "y": "foo", "z": True})
# Set status quo to existing arm
self.batch.set_status_quo_with_weight(self.arms[0], self.sq_weight)
# Status quo weight is set to the average of other arms' weights.
Expand All @@ -158,36 +158,40 @@ def test_StatusQuoOverlap(self) -> None:
self.assertEqual(sum(self.batch.weights), self.weights[1] + self.sq_weight)

# Set status quo to new arm, add it
self.batch.set_status_quo_with_weight(new_sq, self.sq_weight)
self.batch.set_status_quo_with_weight(self.new_sq, self.sq_weight)
self.assertEqual(self.batch.status_quo.name, "status_quo_0")
self.batch.add_arms_and_weights([new_sq])
self.batch.add_arms_and_weights([self.new_sq])
self.assertEqual(
self.batch.generator_run_structs[1].generator_run.arms[0].name,
"status_quo_0",
)

def test_StatusQuo(self) -> None:
tot_weight = sum(self.batch.weights)
new_sq = Arm(parameters={"w": 0.95, "x": 1, "y": "foo", "z": True})

# Test negative weight
def test_status_quo_cannot_have_negative_weight(self) -> None:
with self.assertRaises(ValueError):
self.batch.set_status_quo_with_weight(new_sq, -1)
self.batch.set_status_quo_with_weight(self.new_sq, -1)

def test_status_quo_cannot_be_set_directly(self) -> None:
# Test that directly setting the status quo raises an error
with self.assertRaises(NotImplementedError):
self.batch.status_quo = new_sq
self.batch.status_quo = self.new_sq

def test_status_quo_can_be_set_to_a_new_arm(self) -> None:
tot_weight = sum(self.batch.weights)
# Set status quo to new arm
self.batch.set_status_quo_with_weight(new_sq, self.sq_weight)
self.assertTrue(self.batch.status_quo == new_sq)
self.batch.set_status_quo_with_weight(self.new_sq, self.sq_weight)
self.assertTrue(self.batch.status_quo == self.new_sq)
self.assertEqual(self.batch.status_quo.name, "status_quo_0")
self.assertEqual(sum(self.batch.weights), tot_weight + self.sq_weight)

def test_status_quo_weight_is_ignored_when_none(self) -> None:
tot_weight = sum(self.batch.weights)
# sq weight should be ignored when sq is None
self.batch.unset_status_quo()
self.assertEqual(sum(self.batch.weights), tot_weight)
self.assertIsNone(self.batch.status_quo)
self.assertIsNone(self.batch._status_quo_weight_override)

# Verify experiment status quo gets set on init
def test_status_quo_set_on_clone(self) -> None:
self.experiment.status_quo = self.status_quo
batch2 = self.batch.clone()
self.assertEqual(batch2.status_quo, self.experiment.status_quo)
Expand All @@ -198,24 +202,30 @@ def test_StatusQuo(self) -> None:
self.assertTrue(batch2.status_quo not in batch2.arm_weights)
self.assertEqual(sum(batch2.weights), sum(self.weights))

# Try setting sq to existing arm with different name
def test_status_quo_cannot_be_set_with_different_name(self) -> None:
# Set status quo to new arm
self.batch.set_status_quo_with_weight(self.status_quo, self.sq_weight)
with self.assertRaises(ValueError):
self.batch.set_status_quo_with_weight(
Arm(new_sq.parameters, name="new_name"), 1
Arm(self.status_quo.parameters, name="new_name"), 1
)

def test_StatusQuoOptimizeForPower(self) -> None:
def test_cannot_optimizer_for_power_without_status_quo(self) -> None:
self.experiment.status_quo = None
with self.assertRaises(ValueError):
self.experiment.new_batch_trial(optimize_for_power=True)

def test_opt_for_power_sq_weight_is_one_for_empty_trial(self) -> None:
self.experiment.status_quo = self.status_quo
batch = self.experiment.new_batch_trial(optimize_for_power=True)
self.assertEqual(batch._status_quo_weight_override, 1)

self.experiment.status_quo = None
with self.assertRaises(ValueError):
batch = self.experiment.new_batch_trial(optimize_for_power=True)

batch.add_arms_and_weights(arms=[])
self.assertTrue(batch._status_quo_weight_override, 1)

def test_opt_for_power_sq_weight_is_sqrt_k(self) -> None:
self.experiment.status_quo = self.status_quo
batch = self.experiment.new_batch_trial(optimize_for_power=True)
batch.add_arms_and_weights(arms=self.arms, weights=self.weights)
expected_status_quo_weight = math.sqrt(sum(self.weights))
self.assertTrue(
Expand All @@ -227,6 +237,13 @@ def test_StatusQuoOptimizeForPower(self) -> None:
)
)

def test_cannot_opt_for_power_without_status_quo(self) -> None:
self.experiment.status_quo = None
with self.assertRaisesRegex(
ValueError, "Can only optimize for power if experiment has a status quo."
):
self.experiment.new_batch_trial(optimize_for_power=True)

def test_ArmsByName(self) -> None:
# Initializes empty
newbatch = self.experiment.new_batch_trial()
Expand Down
4 changes: 0 additions & 4 deletions ax/storage/sqa_store/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@
"_steps",
"analysis_scheduler",
"_nodes",
# ``status_quo_weight_override`` is a field on ``BatchTrial`` not in the
# "trial_v2" table
# TODO(T193258337)
"_status_quo_weight_override",
}
SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."

Expand Down

0 comments on commit 9dff7f2

Please sign in to comment.