Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add init_position argument to UniformGenerator #2686

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion ax/models/random/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class RandomModel(Model):
of the model will not return the same point twice. This flag
is used in rejection sampling.
seed: An optional seed value for scrambling.
init_position: The initial state of the generator. This is the number
of samples to fast-forward before generating new samples.
Used to ensure that the re-loaded generator will continue generating
from the same sequence rather than starting from scratch.
generated_points: A set of previously generated points to use
for deduplication. These should be provided in the raw transformed
space the model operates in.
Expand All @@ -59,6 +63,7 @@ def __init__(
self,
deduplicate: bool = True,
seed: Optional[int] = None,
init_position: int = 0,
generated_points: Optional[np.ndarray] = None,
fallback_to_sample_polytope: bool = False,
) -> None:
Expand All @@ -69,6 +74,7 @@ def __init__(
if seed is not None
else checked_cast(int, torch.randint(high=100_000, size=(1,)).item())
)
self.init_position = init_position
# Used for deduplication.
self.generated_points = generated_points
self.fallback_to_sample_polytope = fallback_to_sample_polytope
Expand Down Expand Up @@ -180,7 +186,13 @@ def gen(
@copy_doc(Model._get_state)
def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
state.update({"seed": self.seed, "generated_points": self.generated_points})
state.update(
{
"seed": self.seed,
"init_position": self.init_position,
"generated_points": self.generated_points,
}
)
return state

def _gen_unconstrained(
Expand Down
16 changes: 3 additions & 13 deletions ax/models/random/sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

# pyre-strict

from typing import Any, Callable, Optional
from typing import Callable, Optional

import numpy as np
import torch
from ax.models.base import Model
from ax.models.model_utils import tunable_feature_indices
from ax.models.random.base import RandomModel
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import not_none
from torch.quasirandom import SobolEngine

Expand All @@ -26,17 +24,15 @@ class SobolGenerator(RandomModel):
the fit or predict methods.
Attributes:
init_position: The initial state of the Sobol generator.
Starts at 0 by default.
scramble: If True, permutes the parameter values among
the elements of the Sobol sequence. Default is True.
See base `RandomModel` for a description of remaining attributes.
"""

def __init__(
self,
seed: Optional[int] = None,
deduplicate: bool = True,
seed: Optional[int] = None,
init_position: int = 0,
scramble: bool = True,
generated_points: Optional[np.ndarray] = None,
Expand All @@ -45,10 +41,10 @@ def __init__(
super().__init__(
deduplicate=deduplicate,
seed=seed,
init_position=init_position,
generated_points=generated_points,
fallback_to_sample_polytope=fallback_to_sample_polytope,
)
self.init_position = init_position
self.scramble = scramble
# Initialize engine on gen.
self._engine: Optional[SobolEngine] = None
Expand Down Expand Up @@ -121,12 +117,6 @@ def gen(
self.init_position = not_none(self.engine).num_generated
return (points, weights)

@copy_doc(Model._get_state)
def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
state.update({"init_position": self.init_position})
return state

def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
"""Generate n samples.
Expand Down
9 changes: 7 additions & 2 deletions ax/models/random/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
from ax.models.random.base import RandomModel
from scipy.stats import uniform


class UniformGenerator(RandomModel):
Expand All @@ -26,16 +25,21 @@ def __init__(
self,
deduplicate: bool = True,
seed: Optional[int] = None,
init_position: int = 0,
generated_points: Optional[np.ndarray] = None,
fallback_to_sample_polytope: bool = False,
) -> None:
super().__init__(
deduplicate=deduplicate,
seed=seed,
init_position=init_position,
generated_points=generated_points,
fallback_to_sample_polytope=fallback_to_sample_polytope,
)
self._rs = np.random.RandomState(seed=self.seed)
if self.init_position > 0:
# Fast-forward the random state by generating & discarding samples.
self._rs.uniform(size=(self.init_position))

def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
"""Generate samples from the scipy uniform distribution.
Expand All @@ -48,4 +52,5 @@ def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
samples: An (n x d) array of random points.
"""
return uniform.rvs(size=(n, tunable_d), random_state=self._rs) # pyre-ignore
self.init_position += n * tunable_d
return self._rs.uniform(size=(n, tunable_d))
105 changes: 48 additions & 57 deletions ax/models/tests/test_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,34 @@
class UniformGeneratorTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.tunable_param_bounds = (0, 1)
self.fixed_param_bounds = (1, 100)
self.tunable_param_bounds = (0.0, 1.0)
self.fixed_param_bounds = (1.0, 100.0)
self.seed = 0
self.expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338],
[0.54488318, 0.4236548, 0.64589411],
[0.43758721, 0.891773, 0.96366276],
]
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _create_bounds(self, n_tunable, n_fixed):
def _create_bounds(self, n_tunable: int, n_fixed: int) -> list[tuple[float, float]]:
tunable_bounds = [self.tunable_param_bounds] * n_tunable
fixed_bounds = [self.fixed_param_bounds] * n_fixed
return tunable_bounds + fixed_bounds

def test_UniformGeneratorAllTunable(self) -> None:
generator = UniformGenerator(seed=0)
def test_with_all_tunable(self) -> None:
generator = UniformGenerator(seed=self.seed)
bounds = self._create_bounds(n_tunable=3, n_fixed=0)
generated_points, weights = generator.gen(
n=3, bounds=bounds, rounding_func=lambda x: x
)

expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338],
[0.54488318, 0.4236548, 0.64589411],
[0.43758721, 0.891773, 0.96366276],
]
)
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))
self.assertTrue(np.shape(self.expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(self.expected_points, generated_points))
self.assertTrue(np.all(weights == 1.0))

def test_UniformGeneratorFixedSpace(self) -> None:
generator = UniformGenerator(seed=0)
def test_with_fixed_space(self) -> None:
generator = UniformGenerator(seed=self.seed)
bounds = self._create_bounds(n_tunable=0, n_fixed=2)
n = 3
with self.assertRaises(SearchSpaceExhausted):
Expand All @@ -55,7 +53,7 @@ def test_UniformGeneratorFixedSpace(self) -> None:
fixed_features={0: 1, 1: 2},
rounding_func=lambda x: x,
)
generator = UniformGenerator(seed=0, deduplicate=False)
generator = UniformGenerator(seed=self.seed, deduplicate=False)
generated_points, _ = generator.gen(
n=3,
bounds=bounds,
Expand All @@ -66,57 +64,50 @@ def test_UniformGeneratorFixedSpace(self) -> None:
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))

def test_UniformGeneratorOnline(self) -> None:
def test_generating_one_by_one(self, init_position: int = 0) -> None:
# Verify that the generator will return the expected arms if called
# one at a time.
generator = UniformGenerator(seed=0)
generator = UniformGenerator(seed=self.seed, init_position=init_position)
n_tunable = fixed_param_index = 3
bounds = self._create_bounds(n_tunable=n_tunable, n_fixed=1)

n = 3
expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338, 1],
[0.54488318, 0.4236548, 0.64589411, 1],
[0.43758721, 0.891773, 0.96366276, 1],
]
)
for i in range(n):
for i in range(init_position, 3):
generated_points, weights = generator.gen(
n=1,
bounds=bounds,
fixed_features={fixed_param_index: 1},
rounding_func=lambda x: x,
)
self.assertEqual(weights, [1])
self.assertTrue(np.allclose(generated_points, expected_points[i, :]))
self.assertTrue(
np.allclose(generated_points[..., :-1], self.expected_points[i, :])
)
self.assertEqual(generated_points[..., -1], 1)
self.assertEqual(generator.init_position, (i + 1) * n_tunable)

def test_UniformGeneratorReseed(self) -> None:
# Verify that the generator will return the expected arms if called
# one at a time.
generator = UniformGenerator(seed=0)
n_tunable = fixed_param_index = 3
bounds = self._create_bounds(n_tunable=n_tunable, n_fixed=1)
def test_with_init_position(self) -> None:
# These are multiples of 3 since there are 3 tunable parameters.
self.test_generating_one_by_one(init_position=3)
self.test_generating_one_by_one(init_position=6)

n = 3
expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338, 1],
[0.54488318, 0.4236548, 0.64589411, 1],
[0.43758721, 0.891773, 0.96366276, 1],
]
def test_with_reloaded_state(self) -> None:
# Check that a reloaded generator will produce the same samples.
org_generator = UniformGenerator()
bounds = self._create_bounds(n_tunable=3, n_fixed=0)
# Generate some to advance the state.
org_generator.gen(n=3, bounds=bounds, rounding_func=lambda x: x)
# Construct a new generator with the state.
new_generator = UniformGenerator(**org_generator._get_state())
# Compare the generated samples.
org_samples, _ = org_generator.gen(
n=3, bounds=bounds, rounding_func=lambda x: x
)
for i in range(n):
generated_points, weights = generator.gen(
n=1,
bounds=bounds,
fixed_features={fixed_param_index: 1},
rounding_func=lambda x: x,
)
self.assertEqual(weights, [1])
self.assertTrue(np.allclose(generated_points, expected_points[i, :]))
new_samples, _ = new_generator.gen(
n=3, bounds=bounds, rounding_func=lambda x: x
)
self.assertTrue(np.allclose(org_samples, new_samples))

def test_UniformGeneratorWithOrderConstraints(self) -> None:
def test_with_order_constraints(self) -> None:
# Enforce dim_0 <= dim_1 <= dim_2 <= dim_3.
# Enforce both fixed and tunable constraints.
generator = UniformGenerator(seed=0)
Expand All @@ -143,7 +134,7 @@ def test_UniformGeneratorWithOrderConstraints(self) -> None:
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))

def test_UniformGeneratorWithLinearConstraints(self) -> None:
def test_with_linear_constraints(self) -> None:
# Enforce dim_0 <= dim_1 <= dim_2 <= dim_3.
# Enforce both fixed and tunable constraints.
generator = UniformGenerator(seed=0)
Expand All @@ -169,7 +160,7 @@ def test_UniformGeneratorWithLinearConstraints(self) -> None:
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))

def test_UniformGeneratorBadBounds(self) -> None:
def test_with_bad_bounds(self) -> None:
generator = UniformGenerator()
with self.assertRaises(ValueError):
generated_points, weights = generator.gen(
Expand Down
2 changes: 2 additions & 0 deletions ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def is_ax_equal(one_val: Any, other_val: Any) -> bool:
dates, and numpy arrays. This method and ``same_elements`` function
as a recursive unit.
"""
if type(one_val) is not type(other_val):
return False
if isinstance(one_val, list) and isinstance(other_val, list):
return same_elements(one_val, other_val)
elif isinstance(one_val, dict) and isinstance(other_val, dict):
Expand Down
6 changes: 6 additions & 0 deletions ax/utils/common/tests/test_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
dataframe_equals,
datetime_equals,
equality_typechecker,
is_ax_equal,
object_attribute_dicts_find_unequal_fields,
same_elements,
)
Expand Down Expand Up @@ -81,3 +82,8 @@ def test_numpy_equals(self) -> None:
self.assertEqual(
object_attribute_dicts_find_unequal_fields(np_0, np_1), ({}, {})
)

def test_is_ax_equal_with_different_types(self) -> None:
self.assertFalse(is_ax_equal(1, np.random.random(5)))
self.assertFalse(is_ax_equal(1, np.ones(5)))
self.assertFalse(is_ax_equal(1, np.ones(1)))
Loading