Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid errors in telemetry due to node-based GenerationStrategy #2554

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/telemetry/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AxClientCreatedRecord:

# Dimensionality of transformed SearchSpace can often be much higher due to one-hot
# encoding of unordered ChoiceParameters
transformed_dimensionality: int
transformed_dimensionality: Optional[int]

@classmethod
def from_ax_client(cls, ax_client: AxClient) -> AxClientCreatedRecord:
Expand Down
18 changes: 13 additions & 5 deletions ax/telemetry/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@

# pyre-strict

import warnings
from datetime import datetime
from typing import Any, Dict, List, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type

from ax.core.experiment import Experiment

from ax.exceptions.core import AxWarning
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy

from ax.modelbridge.modelbridge_utils import (
extract_search_space_digest,
transform_search_space,
)

from ax.modelbridge.registry import ModelRegistryBase, Models, SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cast import Cast
Expand All @@ -32,11 +31,20 @@

def _get_max_transformed_dimensionality(
search_space: SearchSpace, generation_strategy: GenerationStrategy
) -> int:
) -> Optional[int]:
"""
Get dimensionality of transformed SearchSpace for all steps in the
GenerationStrategy and return the maximum.
"""
if generation_strategy.is_node_based:
warnings.warn(
"`_get_max_transformed_dimensionality` does not fully support node-based "
"generation strategies. This will result in an incomplete record.",
category=AxWarning,
stacklevel=4,
)
# TODO [T192965545]: Support node-based generation strategies in telemetry
return None

transforms_by_step = [
_extract_transforms_and_configs(step=step)
Expand Down
29 changes: 25 additions & 4 deletions ax/telemetry/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass
from math import inf
from typing import Optional

from ax.exceptions.core import AxWarning
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS

Expand All @@ -26,17 +29,35 @@ class GenerationStrategyCreatedRecord:
generation_strategy_name: str

# -1 indicates unlimited trials requested, 0 indicates no trials requested
num_requested_initialization_trials: int # Typically the number of Sobol trials
num_requested_bayesopt_trials: int
num_requested_other_trials: int
num_requested_initialization_trials: Optional[
int # Typically the number of Sobol trials
]
num_requested_bayesopt_trials: Optional[int]
num_requested_other_trials: Optional[int]

# Minimum `max_parallelism` across GenerationSteps, i.e. the bottleneck
max_parallelism: int
max_parallelism: Optional[int]

@classmethod
def from_generation_strategy(
cls, generation_strategy: GenerationStrategy
) -> GenerationStrategyCreatedRecord:
if generation_strategy.is_node_based:
warnings.warn(
"`GenerationStrategyCreatedRecord` does not fully support node-based "
"generation strategies. This will result in an incomplete record.",
category=AxWarning,
stacklevel=4,
)
# TODO [T192965545]: Support node-based generation strategies in telemetry
return cls(
generation_strategy_name=generation_strategy.name,
num_requested_initialization_trials=None,
num_requested_bayesopt_trials=None,
num_requested_other_trials=None,
max_parallelism=None,
)

# Minimum `max_parallelism` across GenerationSteps, i.e. the bottleneck
true_max_parallelism = min(
step.max_parallelism or inf for step in generation_strategy._steps
Expand Down
2 changes: 1 addition & 1 deletion ax/telemetry/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class SchedulerCreatedRecord:

# Dimensionality of transformed SearchSpace can often be much higher due to one-hot
# encoding of unordered ChoiceParameters
transformed_dimensionality: int
transformed_dimensionality: Optional[int]

@classmethod
def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
Expand Down
24 changes: 23 additions & 1 deletion ax/telemetry/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

# pyre-strict

from ax.exceptions.core import AxWarning
from ax.telemetry.generation_strategy import GenerationStrategyCreatedRecord
from ax.utils.common.testutils import TestCase
from ax.utils.testing.modeling_stubs import get_generation_strategy
from ax.utils.testing.modeling_stubs import (
get_generation_strategy,
sobol_gpei_generation_node_gs,
)


class TestGenerationStrategy(TestCase):
Expand All @@ -25,3 +29,21 @@ def test_generation_strategy_created_record_from_generation_strategy(self) -> No
max_parallelism=3,
)
self.assertEqual(record, expected)

def test_generation_strategy_created_record_node_based(self) -> None:
gs = sobol_gpei_generation_node_gs()
with self.assertWarnsRegex(
AxWarning,
"`GenerationStrategyCreatedRecord` does not fully support node-based*",
):
record = GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=gs
)
expected = GenerationStrategyCreatedRecord(
generation_strategy_name="Sobol+GPEI_Nodes",
num_requested_initialization_trials=None,
num_requested_bayesopt_trials=None,
num_requested_other_trials=None,
max_parallelism=None,
)
self.assertEqual(record, expected)