From e3736fc29e4a23043051e298ed727ce85dc53c82 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 12 Apr 2024 16:23:11 -0700 Subject: [PATCH] Typing improvements to `RangeParameter` (#2327) Summary: This commit improves the typing coverage for `RangeParameter` and its downstream applications, as well as getters and setters four its bounds that include bounds validation. Reviewed By: bernardbeckerman Differential Revision: D55805080 --- ax/core/batch_trial.py | 5 +- ax/core/parameter.py | 108 +++++++++--------- ax/core/parameter_constraint.py | 3 +- ax/core/search_space.py | 6 +- ax/core/tests/test_parameter.py | 4 +- ax/modelbridge/dispatch_utils.py | 2 +- .../transforms/int_range_to_choice.py | 13 ++- ax/modelbridge/transforms/int_to_float.py | 2 +- ax/plot/helper.py | 13 +-- ax/telemetry/experiment.py | 6 +- ax/utils/testing/core_stubs.py | 11 +- ax/utils/testing/modeling_stubs.py | 6 +- 12 files changed, 89 insertions(+), 90 deletions(-) diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index 985e57ca6c7..3b75293400b 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -482,10 +482,7 @@ def is_factorial(self) -> bool: param_levels: DefaultDict[str, Dict[Union[str, float], int]] = defaultdict(dict) for arm in self.arms: for param_name, param_value in arm.parameters.items(): - # Expected `Union[float, str]` for 2nd anonymous parameter to call - # `dict.__setitem__` but got `Optional[Union[bool, float, str]]`. - # pyre-fixme[6]: Expected `Union[float, str]` for 1st param but got `... - param_levels[param_name][param_value] = 1 + param_levels[param_name][not_none(param_value)] = 1 param_cardinality = 1 for param_values in param_levels.values(): param_cardinality *= len(param_values) diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 7b1c59422df..94d6104fa78 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -12,13 +12,14 @@ from copy import deepcopy from enum import Enum from math import inf -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import cast, Dict, List, Optional, Tuple, Type, Union from warnings import warn -from ax.core.types import TParamValue, TParamValueList +from ax.core.types import TNumeric, TParamValue, TParamValueList from ax.exceptions.core import AxWarning, UserInputError from ax.utils.common.base import SortableBase from ax.utils.common.typeutils import not_none +from pyre_extensions import assert_is_instance # Tolerance for floating point comparisons. This is relatively permissive, # and allows for serializing at rather low numerical precision. @@ -81,7 +82,7 @@ def _get_parameter_type(python_type: Type) -> ParameterType: class Parameter(SortableBase, metaclass=ABCMeta): _is_fidelity: bool = False _name: str - _target_value: Optional[TParamValue] = None + _target_value: TParamValue = None _parameter_type: ParameterType def cast(self, value: TParamValue) -> TParamValue: @@ -125,7 +126,7 @@ def is_hierarchical(self) -> bool: ) @property - def target_value(self) -> Optional[TParamValue]: + def target_value(self) -> TParamValue: return self._target_value @property @@ -234,7 +235,7 @@ def __init__( logit_scale: bool = False, digits: Optional[int] = None, is_fidelity: bool = False, - target_value: Optional[TParamValue] = None, + target_value: TParamValue = None, ) -> None: """Initialize RangeParameter @@ -259,17 +260,16 @@ def __init__( ) self._name = name + if parameter_type not in (ParameterType.INT, ParameterType.FLOAT): + raise UserInputError("RangeParameter type must be int or float.") self._parameter_type = parameter_type self._digits = digits - # pyre-fixme[4]: Attribute must be annotated. - self._lower = self.cast(lower) - # pyre-fixme[4]: Attribute must be annotated. - self._upper = self.cast(upper) + self._lower: TNumeric = not_none(self.cast(lower)) + self._upper: TNumeric = not_none(self.cast(upper)) self._log_scale = log_scale self._logit_scale = logit_scale self._is_fidelity = is_fidelity - # pyre-fixme[4]: Attribute must be annotated. - self._target_value = self.cast(target_value) + self._target_value: Optional[TNumeric] = self.cast(target_value) self._validate_range_param( parameter_type=parameter_type, @@ -279,16 +279,16 @@ def __init__( logit_scale=logit_scale, ) - def cardinality(self) -> float: + def cardinality(self) -> TNumeric: if self.parameter_type == ParameterType.FLOAT: return inf - return self.upper - self.lower + 1 + return int(self.upper) - int(self.lower) + 1 def _validate_range_param( self, - lower: TParamValue, - upper: TParamValue, + lower: TNumeric, + upper: TNumeric, log_scale: bool, logit_scale: bool, parameter_type: Optional[ParameterType] = None, @@ -298,15 +298,13 @@ def _validate_range_param( ParameterType.FLOAT, ): raise UserInputError("RangeParameter type must be int or float.") - # pyre-fixme[58]: `>=` is not supported for operand types `Union[None, bool, - # float, int, str]` and `Union[None, bool, float, int, str]`. + + upper = float(upper) if lower >= upper: raise UserInputError( f"Upper bound of {self.name} must be strictly larger than lower." f"Got: ({lower}, {upper})." ) - # pyre-fixme[58]: `-` is not supported for operand types `Union[None, bool, - # float, int, str]` and `Union[None, bool, float, int, str]`. width: float = upper - lower if width < 100 * EPS: raise UserInputError( @@ -316,12 +314,8 @@ def _validate_range_param( ) if log_scale and logit_scale: raise UserInputError("Can't use both log and logit.") - # pyre-fixme[58]: `<=` is not supported for operand types `Union[None, bool, - # float, int, str]` and `int`. if log_scale and lower <= 0: raise UserInputError("Cannot take log when min <= 0.") - # pyre-fixme[58]: `<=` is not supported for operand types `Union[None, bool, - # float, int, str]` and `int`. if logit_scale and (lower <= 0 or upper >= 1): raise UserInputError("Logit requires lower > 0 and upper < 1") if not (self.is_valid_type(lower)) or not (self.is_valid_type(upper)): @@ -330,7 +324,7 @@ def _validate_range_param( ) @property - def upper(self) -> float: + def upper(self) -> TNumeric: """Upper bound of the parameter range. Value is cast to parameter type upon set and also validated @@ -338,8 +332,18 @@ def upper(self) -> float: """ return self._upper + @upper.setter + def upper(self, value: TNumeric) -> None: + self._validate_range_param( + lower=self.lower, + upper=value, + log_scale=self.log_scale, + logit_scale=self.logit_scale, + ) + self._upper = not_none(self.cast(value)) + @property - def lower(self) -> float: + def lower(self) -> TNumeric: """Lower bound of the parameter range. Value is cast to parameter type upon set and also validated @@ -347,6 +351,16 @@ def lower(self) -> float: """ return self._lower + @lower.setter + def lower(self, value: TNumeric) -> None: + self._validate_range_param( + lower=value, + upper=self.upper, + log_scale=self.log_scale, + logit_scale=self.logit_scale, + ) + self._lower = not_none(self.cast(value)) + @property def digits(self) -> Optional[int]: """Number of digits to round values to for float type. @@ -381,8 +395,8 @@ def update_range( if upper is None: upper = self._upper - cast_lower = self.cast(lower) - cast_upper = self.cast(upper) + cast_lower = not_none(self.cast(lower)) + cast_upper = not_none(self.cast(upper)) self._validate_range_param( lower=cast_lower, upper=cast_upper, @@ -397,10 +411,8 @@ def set_digits(self, digits: int) -> RangeParameter: self._digits = digits # Re-scale min and max to new digits definition - cast_lower = self.cast(self._lower) - cast_upper = self.cast(self._upper) - # pyre-fixme[58]: `>=` is not supported for operand types `Union[None, bool, - # float, int, str]` and `Union[None, bool, float, int, str]`. + cast_lower = not_none(self.cast(self._lower)) + cast_upper = not_none(self.cast(self._upper)) if cast_lower >= cast_upper: raise UserInputError( f"Lower bound {cast_lower} is >= upper bound {cast_upper}." @@ -451,9 +463,7 @@ def is_valid_type(self, value: TParamValue) -> bool: # This might have issues with ints > 2^24 if self.parameter_type is ParameterType.INT: - # pyre-fixme[6]: Expected `Union[_SupportsIndex, bytearray, bytes, str, - # typing.SupportsFloat]` for 1st param but got `Union[None, float, str]`. - return isinstance(value, int) or float(value).is_integer() + return isinstance(value, int) or float(not_none(value)).is_integer() return True def clone(self) -> RangeParameter: @@ -469,13 +479,12 @@ def clone(self) -> RangeParameter: target_value=self._target_value, ) - def cast(self, value: TParamValue) -> TParamValue: + def cast(self, value: TParamValue) -> Optional[TNumeric]: if value is None: return None if self.parameter_type is ParameterType.FLOAT and self._digits is not None: - # pyre-fixme[6]: Expected `None` for 2nd param but got `Optional[int]`. - return round(float(value), self._digits) - return self.python_type(value) + return round(float(value), not_none(self._digits)) + return assert_is_instance(self.python_type(value), TNumeric) def __repr__(self) -> str: ret_val = self._base_repr() @@ -526,7 +535,7 @@ def __init__( is_ordered: Optional[bool] = None, is_task: bool = False, is_fidelity: bool = False, - target_value: Optional[TParamValue] = None, + target_value: TParamValue = None, sort_values: Optional[bool] = None, dependents: Optional[Dict[TParamValue, List[str]]] = None, ) -> None: @@ -561,9 +570,8 @@ def __init__( stacklevel=2, ) values = list(dict_values) - self._values: List[TParamValue] = self._cast_values(values) - # pyre-fixme[4]: Attribute must be annotated. - self._is_ordered = ( + + self._is_ordered: bool = ( is_ordered if is_ordered is not None else self._get_default_bool_and_warn(param_string="is_ordered") @@ -575,11 +583,9 @@ def __init__( else self._get_default_bool_and_warn(param_string="sort_values") ) if self.sort_values: - # pyre-ignore[6]: values/self._values expects List[Union[None, bool, float, - # int, str]] but sorted() takes/returns - # List[Variable[_typeshed.SupportsLessThanT (bound to - # _typeshed.SupportsLessThan)]] - self._values = self._cast_values(sorted(values)) + values = cast(List[TParamValue], sorted([not_none(v) for v in values])) + self._values: List[TParamValue] = self._cast_values(values) + if dependents: for value in dependents: if value not in self.values: @@ -714,7 +720,7 @@ def __init__( parameter_type: ParameterType, value: TParamValue, is_fidelity: bool = False, - target_value: Optional[TParamValue] = None, + target_value: TParamValue = None, dependents: Optional[Dict[TParamValue, List[str]]] = None, ) -> None: """Initialize FixedParameter @@ -737,11 +743,9 @@ def __init__( self._name = name self._parameter_type = parameter_type - # pyre-fixme[4]: Attribute must be annotated. - self._value = self.cast(value) + self._value: TParamValue = self.cast(value) self._is_fidelity = is_fidelity - # pyre-fixme[4]: Attribute must be annotated. - self._target_value = self.cast(target_value) + self._target_value: TParamValue = self.cast(target_value) # NOTE: We don't need to check that dependent parameters actually exist as # that is done in `HierarchicalSearchSpace` constructor. if dependents: diff --git a/ax/core/parameter_constraint.py b/ax/core/parameter_constraint.py index f22a739db44..1c5f3f0f82f 100644 --- a/ax/core/parameter_constraint.py +++ b/ax/core/parameter_constraint.py @@ -21,11 +21,10 @@ class ParameterConstraint(SortableBase): Constraints are expressed using a map from parameter name to weight followed by a bound. - The constraint is satisfied if w * v <= b where: + The constraint is satisfied if sum_i(w_i * v_i) <= b where: w is the vector of parameter weights. v is a vector of parameter values. b is the specified bound. - * is the dot product operator. """ def __init__(self, constraint_dict: Dict[str, float], bound: float) -> None: diff --git a/ax/core/search_space.py b/ax/core/search_space.py index ce05386876b..b84a0d4f3cf 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -238,8 +238,7 @@ def check_membership( # parameter constraints only accept numeric parameters numerical_param_dict = { - # pyre-fixme[6]: Expected `typing.Union[...oat]` but got `unknown`. - name: float(value) + name: float(not_none(value)) for name, value in parameterization.items() if self.parameters[name].is_numeric } @@ -544,7 +543,8 @@ def flatten_observation_features( # that behavior was requested via the opt-in flag. warnings.warn( f"Cannot flatten observation features {obs_feats} as full " - "parameterization is not recorded in metadata." + "parameterization is not recorded in metadata.", + stacklevel=2, ) return obs_feats diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index 2cae278ad89..d9ea6227f12 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -6,6 +6,8 @@ # pyre-strict +from typing import cast, List + from ax.core.parameter import ( _get_parameter_type, ChoiceParameter, @@ -299,7 +301,7 @@ def test_Properties(self) -> None: ) self.assertTrue(int_param.is_ordered) self.assertListEqual( - int_param.values, sorted(int_param.values) # pyre-fixme[6] + int_param.values, sorted(cast(List[int], int_param.values)) ) float_param = ChoiceParameter( name="x", parameter_type=ParameterType.FLOAT, values=[1.5, 2.5, 3.5] diff --git a/ax/modelbridge/dispatch_utils.py b/ax/modelbridge/dispatch_utils.py index 548c5ae552f..21a6cb06f7f 100644 --- a/ax/modelbridge/dispatch_utils.py +++ b/ax/modelbridge/dispatch_utils.py @@ -189,7 +189,7 @@ def _suggest_gp_model( if parameter.parameter_type == ParameterType.FLOAT: all_range_parameters_are_discrete = False else: - num_param_discrete_values = int(parameter.upper - parameter.lower) + 1 + num_param_discrete_values = parameter.cardinality() num_possible_points *= num_param_discrete_values if should_enumerate_param: diff --git a/ax/modelbridge/transforms/int_range_to_choice.py b/ax/modelbridge/transforms/int_range_to_choice.py index c94b152e79d..04d07ac80d4 100644 --- a/ax/modelbridge/transforms/int_range_to_choice.py +++ b/ax/modelbridge/transforms/int_range_to_choice.py @@ -6,7 +6,8 @@ # pyre-strict -from typing import Dict, List, Optional, Set, TYPE_CHECKING +from numbers import Real +from typing import cast, Dict, List, Optional, Set, TYPE_CHECKING from ax.core.observation import Observation from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter @@ -35,14 +36,16 @@ def __init__( ) -> None: assert search_space is not None, "IntRangeToChoice requires search space" config = config or {} - self.max_choices: float = config.get("max_choices", float("inf")) # pyre-ignore + self.max_choices: float = float( + cast(Real, (config.get("max_choices", float("inf")))) + ) # Identify parameters that should be transformed self.transform_parameters: Set[str] = { p_name for p_name, p in search_space.parameters.items() if isinstance(p, RangeParameter) and p.parameter_type == ParameterType.INT - and p.upper - p.lower + 1 <= self.max_choices + and p.cardinality() <= self.max_choices } def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: @@ -52,9 +55,9 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: p_name in self.transform_parameters and isinstance(p, RangeParameter) and p.parameter_type == ParameterType.INT - and p.upper - p.lower + 1 <= self.max_choices + and p.cardinality() <= self.max_choices ): - values = list(range(p.lower, p.upper + 1)) # pyre-ignore + values = list(range(int(p.lower), int(p.upper) + 1)) target_value = ( None if p.target_value is None diff --git a/ax/modelbridge/transforms/int_to_float.py b/ax/modelbridge/transforms/int_to_float.py index 9b1bf746581..ccdaadb3c4c 100644 --- a/ax/modelbridge/transforms/int_to_float.py +++ b/ax/modelbridge/transforms/int_to_float.py @@ -69,7 +69,7 @@ def __init__( for p_name, p in self.search_space.parameters.items() if isinstance(p, RangeParameter) and p.parameter_type == ParameterType.INT - and (p.upper - p.lower + 1 >= self.min_choices or p.log_scale) + and ((p.cardinality() >= self.min_choices) or p.log_scale) } if contains_constrained_integer(self.search_space, self.transform_parameters): self.rounding = "randomized" diff --git a/ax/plot/helper.py b/ax/plot/helper.py index e8a5cef2a4f..f49b07087ff 100644 --- a/ax/plot/helper.py +++ b/ax/plot/helper.py @@ -15,13 +15,7 @@ import numpy as np from ax.core.generator_run import GeneratorRun from ax.core.observation import Observation, ObservationFeatures -from ax.core.parameter import ( - ChoiceParameter, - FixedParameter, - Parameter, - ParameterType, - RangeParameter, -) +from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, RangeParameter from ax.core.types import TParameterization from ax.modelbridge.base import ModelBridge from ax.modelbridge.prediction_utils import ( @@ -444,10 +438,7 @@ def get_range_parameters_from_list( parameter for parameter in parameters if isinstance(parameter, RangeParameter) - and ( - parameter.parameter_type == ParameterType.FLOAT - or parameter.upper - parameter.lower + 1 >= min_num_values - ) + and parameter.cardinality() >= min_num_values # float has inf cardinality ] diff --git a/ax/telemetry/experiment.py b/ax/telemetry/experiment.py index 637e8ff06ab..eaba3902d04 100644 --- a/ax/telemetry/experiment.py +++ b/ax/telemetry/experiment.py @@ -181,7 +181,7 @@ def _get_param_counts_from_search_space( isinstance(param, RangeParameter) or (isinstance(param, ChoiceParameter) and param.is_ordered) ) - and (1 < param.cardinality() <= 3) + and (1.0 < param.cardinality() <= 3.0) ) num_int_range_parameters_medium = sum( 1 @@ -190,7 +190,7 @@ def _get_param_counts_from_search_space( isinstance(param, RangeParameter) or (isinstance(param, ChoiceParameter) and param.is_ordered) ) - and (3 < param.cardinality() <= 7) + and (3.0 < param.cardinality() <= 7.0) ) num_int_range_parameters_large = sum( 1 @@ -199,7 +199,7 @@ def _get_param_counts_from_search_space( isinstance(param, RangeParameter) or (isinstance(param, ChoiceParameter) and param.is_ordered) ) - and (7 < param.cardinality() < inf) + and (7.0 < param.cardinality() < inf) ) num_log_scale_range_parameters = sum( 1 diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 8ce79c9ea33..37343452497 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -1978,8 +1978,8 @@ def get_branin_data( "metric_name": "branin", "arm_name": not_none(checked_cast(Trial, trial).arm).name, "mean": branin( - float(not_none(trial.arm).parameters["x1"]), # pyre-ignore[6] - float(not_none(trial.arm).parameters["x2"]), # pyre-ignore[6] + float(not_none(not_none(trial.arm).parameters["x1"])), + float(not_none(not_none(trial.arm).parameters["x2"])), ), "sem": 0.0, } @@ -2007,9 +2007,10 @@ def get_branin_data_batch(batch: BatchTrial) -> Data: "arm_name": [arm.name for arm in batch.arms], "metric_name": "branin", "mean": [ - # pyre-ignore[6]: This function can fail if a parameter value - # does not support conversion to float. - branin(float(arm.parameters["x1"]), float(arm.parameters["x2"])) + branin( + float(not_none(arm.parameters["x1"])), + float(not_none(arm.parameters["x2"])), + ) for arm in batch.arms ], "sem": 0.1, diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index b070ac6876a..89efeb06e48 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -506,7 +506,8 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: new_ss = search_space.clone() for param in new_ss.parameters.values(): if isinstance(param, FixedParameter): - param._value += 1.0 + if param._value is not None and not isinstance(param._value, str): + param._value += 1.0 elif isinstance(param, RangeParameter): param._lower += 1.0 param._upper += 1.0 @@ -564,7 +565,8 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: new_ss = search_space.clone() for param in new_ss.parameters.values(): if isinstance(param, FixedParameter): - param._value *= 2.0 + if param._value is not None and not isinstance(param._value, str): + param._value *= 2.0 elif isinstance(param, RangeParameter): param._lower *= 2.0 param._upper *= 2.0