Skip to content

Commit

Permalink
Minor changes to support HSS in telemetry (facebook#1654)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebook#1654

Reviewed By: dme65

Differential Revision: D46614741

fbshipit-source-id: 6b13352fb1f8d9fdbc1ab45b462b5a49469d3b84
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jun 12, 2023
1 parent 516ad32 commit e9989e2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def height(self) -> int:
"""

def _height_from_parameter(parameter: Parameter) -> int:
if len(parameter.dependents) == 0:
if not parameter.is_hierarchical:
return 1

return (
Expand Down
3 changes: 2 additions & 1 deletion ax/telemetry/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ax.modelbridge.registry import ModelRegistryBase, Models, SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cast import Cast

# Models whose generated trails will count towards initialization_trials
INITIALIZATION_MODELS: List[Models] = [Models.SOBOL, Models.UNIFORM]
Expand All @@ -38,7 +39,7 @@ def _get_max_transformed_dimensionality(
transformed_search_spaces = [
transform_search_space(
search_space=search_space,
transforms=transforms,
transforms=[Cast] + transforms,
transform_configs=transform_configs,
)
for transforms, transform_configs in transforms_by_step
Expand Down
58 changes: 58 additions & 0 deletions ax/telemetry/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Sequence, Union

from ax.core.types import TParamValue
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.telemetry.ax_client import AxClientCompletedRecord, AxClientCreatedRecord
from ax.telemetry.experiment import ExperimentCompletedRecord, ExperimentCreatedRecord
Expand Down Expand Up @@ -42,6 +45,61 @@ def test_ax_client_created_record_from_ax_client(self) -> None:
)
self.assertEqual(record, expected)

# Test with HSS & MOO.
ax_client = AxClient()
parameters: List[
Dict[str, Union[TParamValue, Sequence[TParamValue], Dict[str, List[str]]]]
] = [
{
"name": "SearchSpace.optimizer",
"type": "choice",
"values": ["Adam", "SGD", "Adagrad"],
"dependents": None,
"is_ordered": False,
},
{"name": "SearchSpace.lr", "type": "range", "bounds": [0.001, 0.1]},
{"name": "SearchSpace.fixed", "type": "fixed", "value": 12.0},
{
"name": "SearchSpace",
"type": "fixed",
"value": "SearchSpace",
"dependents": {
"SearchSpace": [
"SearchSpace.optimizer",
"SearchSpace.lr",
"SearchSpace.fixed",
]
},
},
]
ax_client.create_experiment(
name="hss_experiment",
parameters=parameters,
objectives={
"branin": ObjectiveProperties(minimize=True),
"b2": ObjectiveProperties(minimize=False),
},
is_test=True,
)
record = AxClientCreatedRecord.from_ax_client(ax_client=ax_client)

expected = AxClientCreatedRecord(
experiment_created_record=ExperimentCreatedRecord.from_experiment(
experiment=ax_client.experiment
),
generation_strategy_created_record=(
GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=ax_client.generation_strategy
)
),
arms_per_trial=1,
early_stopping_strategy_cls=None,
global_stopping_strategy_cls=None,
transformed_dimensionality=4,
)
self.assertEqual(record, expected)
self.assertEqual(record.experiment_created_record.hierarchical_tree_height, 2)

def test_ax_client_completed_record_from_ax_client(self) -> None:
ax_client = AxClient()
ax_client.create_experiment(
Expand Down

0 comments on commit e9989e2

Please sign in to comment.