Skip to content

Commit

Permalink
simplify grid build
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed May 3, 2015
1 parent f4af925 commit b7a7b9b
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# limitations under the License.
#

import itertools

__all__ = ['ParamGridBuilder']


Expand Down Expand Up @@ -76,17 +78,9 @@ def build(self):
Builds and returns all combinations of parameters specified
by the param grid.
"""
param_maps = [{}]
for (param, values) in self._param_grid.items():
new_param_maps = []
for value in values:
for old_map in param_maps:
copied_map = old_map.copy()
copied_map[param] = value
new_param_maps.append(copied_map)
param_maps = new_param_maps

return param_maps
keys = self._param_grid.keys()
grid_values = self._param_grid.values()
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]


if __name__ == "__main__":
Expand Down

0 comments on commit b7a7b9b

Please sign in to comment.