Skip to content

Commit

Permalink
[Tune] Fix AxSearch save and nan/inf result handling (#31147)
Browse files Browse the repository at this point in the history
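This PR fixes AxSearch checkpoint saving by serializing the searcher state with cloudpickle instead of pickle, and abandons trials that report nan/inf metrics instead of passing those values to Ax.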

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored and AmeerHajAli committed Jan 12, 2023
1 parent 4ef392d commit ea11fd5
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 141 deletions.
24 changes: 17 additions & 7 deletions python/ray/tune/search/ax/ax_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy
import pickle
import numpy as np
from typing import Dict, List, Optional, Union

from ray import cloudpickle
from ray.tune.result import DEFAULT_METRIC
from ray.tune.search.sample import (
Categorical,
Expand Down Expand Up @@ -151,7 +152,7 @@ def __init__(
parameter_constraints: Optional[List] = None,
outcome_constraints: Optional[List] = None,
ax_client: Optional[AxClient] = None,
**ax_kwargs
**ax_kwargs,
):
assert (
ax is not None
Expand Down Expand Up @@ -324,12 +325,21 @@ def on_trial_complete(self, trial_id, result=None, error=False):

def _process_result(self, trial_id, result):
ax_trial_index = self._live_trial_mapping[trial_id]
metric_dict = {self._metric: (result[self._metric], None)}
outcome_names = [
metrics_to_include = [self._metric] + [
oc.metric.name
for oc in self._ax.experiment.optimization_config.outcome_constraints
]
metric_dict.update({on: (result[on], None) for on in outcome_names})
metric_dict = {}
for key in metrics_to_include:
val = result[key]
if np.isnan(val) or np.isinf(val):
# Don't report trials with NaN/inf metrics to Ax
self._ax.abandon_trial(
trial_index=ax_trial_index,
reason=f"nan/inf metrics reported by {trial_id}",
)
return
metric_dict[key] = (val, None)
self._ax.complete_trial(trial_index=ax_trial_index, raw_data=metric_dict)

@staticmethod
Expand Down Expand Up @@ -415,9 +425,9 @@ def resolve_value(par, domain):
def save(self, checkpoint_path: str):
save_object = self.__dict__
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(save_object, outputFile)
cloudpickle.dump(save_object, outputFile)

def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
save_object = pickle.load(inputFile)
save_object = cloudpickle.load(inputFile)
self.__dict__.update(save_object)
Loading

0 comments on commit ea11fd5

Please sign in to comment.