Skip to content

Commit

Permalink
fix bug in UnitX when transforming new search space (#2639)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2639

This fixes a significant bug, where a new search space (other than the search space passed to `UnitX.__init__`) that is passed to `UnitX._transform_search_space` is not actually transformed.

A common setting where this comes up is passing a new search space to `Modelbridge.gen` to generate candidates from a particular part of the searchspace. E.g. if the original search space bounds for `x` were [2.0, 5.0] and I want to generate candidates from the restricted search space [3.0, 4.0] by passing a search space to `Modelbridge.gen`, the new search space previously ignored.

Reviewed By: mgarrard

Differential Revision: D60855920
  • Loading branch information
sdaulton authored and facebook-github-bot committed Aug 6, 2024
1 parent b015625 commit ac184d5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 4 deletions.
89 changes: 89 additions & 0 deletions ax/modelbridge/transforms/tests/test_unit_x_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.transforms.unit_x import UnitX
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import get_robust_search_space


Expand Down Expand Up @@ -144,6 +145,94 @@ def test_TransformSearchSpace(self) -> None:
self.search_space_with_target.parameters["x"].target_value, 1.0
)

def test_TransformNewSearchSpace(self) -> None:
new_ss = SearchSpace(
parameters=[
RangeParameter(
"x", lower=1.5, upper=2.0, parameter_type=ParameterType.FLOAT
),
RangeParameter(
"y", lower=1.25, upper=2.0, parameter_type=ParameterType.FLOAT
),
RangeParameter(
"z",
lower=1.0,
upper=1.5,
parameter_type=ParameterType.FLOAT,
log_scale=True,
),
RangeParameter(
"a", lower=0.0, upper=2, parameter_type=ParameterType.INT
),
ChoiceParameter(
"b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
],
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": -0.5, "y": 1}, bound=0.5),
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5),
],
)
self.t.transform_search_space(new_ss)
# Parameters transformed
true_bounds = {
"x": [
0.25 * self.target_range + self.target_lb,
0.5 * self.target_range + self.target_lb,
],
"y": [
0.25 * self.target_range + self.target_lb,
1.0 * self.target_range + self.target_lb,
],
"z": [1.0, 1.5],
"a": [0, 2],
}
for p_name, (l, u) in true_bounds.items():
p = checked_cast(RangeParameter, new_ss.parameters[p_name])
self.assertEqual(p.lower, l)
self.assertEqual(p.upper, u)
self.assertEqual(
checked_cast(ChoiceParameter, new_ss.parameters["b"]).values,
["a", "b", "c"],
)
self.assertEqual(len(new_ss.parameters), 5)
# # Constraints transformed
self.assertEqual(
new_ss.parameter_constraints[0].constraint_dict, self.expected_c_dicts[0]
)
self.assertEqual(
new_ss.parameter_constraints[0].bound, self.expected_c_bounds[0]
)
self.assertEqual(
new_ss.parameter_constraints[1].constraint_dict, self.expected_c_dicts[1]
)
self.assertEqual(
new_ss.parameter_constraints[1].bound, self.expected_c_bounds[1]
)

# Test transform of target value
t = self.transform_class(
search_space=self.search_space_with_target,
observations=[],
)
new_search_space_with_target = SearchSpace(
parameters=[
RangeParameter(
"x",
lower=1,
upper=2,
parameter_type=ParameterType.FLOAT,
is_fidelity=True,
target_value=2,
)
]
)
t.transform_search_space(new_search_space_with_target)
self.assertEqual(
new_search_space_with_target.parameters["x"].target_value,
0.5 * self.target_range + self.target_lb,
)

def test_w_robust_search_space_univariate(self) -> None:
# Check that if no transforms are needed, it is untouched.
for multivariate in (True, False):
Expand Down
10 changes: 6 additions & 4 deletions ax/modelbridge/transforms/unit_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,16 @@ def transform_observation_features(

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
for p_name, p in search_space.parameters.items():
if p_name in self.bounds and isinstance(p, RangeParameter):
if (p_bounds := self.bounds.get(p_name)) is not None and isinstance(
p, RangeParameter
):
p.update_range(
lower=self.target_lb,
upper=self.target_lb + self.target_range,
lower=self._normalize_value(value=p.lower, bounds=p_bounds),
upper=self._normalize_value(value=p.upper, bounds=p_bounds),
)
if p.target_value is not None:
p._target_value = self._normalize_value(
p.target_value, self.bounds[p_name] # pyre-ignore[6]
value=p.target_value, bounds=p_bounds # pyre-ignore [6]
)
new_constraints: List[ParameterConstraint] = []
for c in search_space.parameter_constraints:
Expand Down

0 comments on commit ac184d5

Please sign in to comment.