Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Allow iterators in tune.grid_search #25220

Merged
merged 2 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions python/ray/tune/suggest/variant_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import logging
import re
from collections.abc import Mapping
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import numpy
import random

from ray.tune import TuneError
from ray.tune.sample import Categorical, Domain, Function, RandomState
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -58,13 +57,12 @@ def generate_variants(
yield resolved_vars, spec


def grid_search(values: List) -> Dict[str, List]:
def grid_search(values: Iterable) -> Dict[str, List]:
"""Convenience method for specifying grid search over a value.

Arguments:
values: An iterable whose parameters will be gridded.
"""

return {"grid_search": values}


Expand Down Expand Up @@ -406,10 +404,6 @@ def _try_resolve(v) -> Tuple[bool, Any]:
elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
# Grid search values
grid_values = v["grid_search"]
if not isinstance(grid_values, list):
raise TuneError(
"Grid search expected list of values, got: {}".format(grid_values)
)
return False, Categorical(grid_values).grid()
return True, v

Expand Down
15 changes: 15 additions & 0 deletions python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,21 @@ def testPointsToEvaluateBasicVariantFixedParam(self):
("Pre-set value `2` is not equal to the value of parameter `a`: 1",),
)

def testGridSearchGenerator(self):
from ray.tune.suggest.basic_variant import BasicVariantGenerator

searcher = BasicVariantGenerator(constant_grid_search=False)
exp = Experiment(
run=_mock_objective,
name="test",
config={"parameter": tune.grid_search(range(10))},
num_samples=1,
)
searcher.add_configurations(exp)

trials = [searcher.next_trial() for i in range(10)]
assert [t.config["parameter"] for t in trials] == list(range(10))

def testConstantGridSearchBasicVariant(self):
config = {
"grid": tune.grid_search([1, 2, 3]),
Expand Down