Skip to content

Commit

Permalink
[tune] Allow iterators in tune.grid_search (#25220)
Browse files Browse the repository at this point in the history
`tune.choice` already accepts iterables, the same should be true for `tune.grid_search`.
  • Loading branch information
krfricke authored May 26, 2022
1 parent 7fcea8a commit d0dfac5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
10 changes: 2 additions & 8 deletions python/ray/tune/suggest/variant_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import logging
import re
from collections.abc import Mapping
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import numpy
import random

from ray.tune import TuneError
from ray.tune.sample import Categorical, Domain, Function, RandomState
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -58,13 +57,12 @@ def generate_variants(
yield resolved_vars, spec


def grid_search(values: List) -> Dict[str, List]:
def grid_search(values: Iterable) -> Dict[str, List]:
"""Convenience method for specifying grid search over a value.
Arguments:
values: An iterable whose parameters will be gridded.
"""

return {"grid_search": values}


Expand Down Expand Up @@ -406,10 +404,6 @@ def _try_resolve(v) -> Tuple[bool, Any]:
elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
# Grid search values
grid_values = v["grid_search"]
if not isinstance(grid_values, list):
raise TuneError(
"Grid search expected list of values, got: {}".format(grid_values)
)
return False, Categorical(grid_values).grid()
return True, v

Expand Down
15 changes: 15 additions & 0 deletions python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,21 @@ def testPointsToEvaluateBasicVariantFixedParam(self):
("Pre-set value `2` is not equal to the value of parameter `a`: 1",),
)

def testGridSearchGenerator(self):
from ray.tune.suggest.basic_variant import BasicVariantGenerator

searcher = BasicVariantGenerator(constant_grid_search=False)
exp = Experiment(
run=_mock_objective,
name="test",
config={"parameter": tune.grid_search(range(10))},
num_samples=1,
)
searcher.add_configurations(exp)

trials = [searcher.next_trial() for i in range(10)]
assert [t.config["parameter"] for t in trials] == list(range(10))

def testConstantGridSearchBasicVariant(self):
config = {
"grid": tune.grid_search([1, 2, 3]),
Expand Down

0 comments on commit d0dfac5

Please sign in to comment.