[Tune] Pause trials scheduled with the ASHA scheduler when they reach a milestone #32634
Comments
hmmm, is it necessarily better to use more GPUs towards the end of trials? To fully utilize these GPUs, one would increase the global batch size, which may impact model performance, especially towards the end. As you mentioned, the lr or other hyperparameters may need to be co-adjusted.
This request is about adding functionality to preempt trials that have reached a milestone in favor of those that have yet to reach the same milestone.
Sure, but how would the remaining trials (that have not yet reached a certain milestone) leverage the additional GPUs that get allocated to them? Do those trials need to adjust other hyperparameters in light of a bigger batch size? What is your proposal here?
Some pseudo code would be helpful for me to understand the intended flow and the gap here.
Hey @MahdiNazemi, please find the code below. This code is a bit old, but it should still work fine. It should provide the same implementation as in the paper. We may want to integrate that into Ray Tune if there's demand.
@xwjiang2010, sorry if my initial explanation of the use case was not clear enough. Here are the steps I have in mind:
Because each trial is allocated all resources on a node, the GPUs will be fully utilized regardless of whether we are at the beginning or near the end of our search. Here is how ASHA can be set to work now:
Case 1: use all GPUs per trial
Case 2: use a non-trivial subset of GPUs per trial
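For concreteness, here is a minimal sketch of the two cases (assuming a Ray 2.x setup with Ray Train's TorchTrainer; train_loop_per_worker, the "accuracy" metric, and the ASHA settings are placeholder assumptions, not taken from the thread):

from ray import tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

def train_loop_per_worker(config):
    # Placeholder training loop; it should periodically report a metric
    # such as "accuracy" via ray.train.report.
    ...

# Case 1: every trial is given all eight GPUs, so trials run one at a time.
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)

# Case 2: one GPU per trial, so up to eight trials run in parallel.
# trainer = TorchTrainer(
#     train_loop_per_worker,
#     scaling_config=ScalingConfig(num_workers=1, use_gpu=True),
# )

tuner = tune.Tuner(
    trainer,
    tune_config=tune.TuneConfig(
        num_samples=16,
        scheduler=ASHAScheduler(metric="accuracy", mode="max", max_t=120, grace_period=1),
    ),
)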
@Yard1, thanks for sharing the code. The following line keeps printing the exact same output:
logger.info(f"Choosing trial {trial.config} to run from trialrunner.")
What does it mean?
I believe that means that it wants to run the trial but it cannot, due to e.g. all resources being claimed. You should be able to simply remove that code line.
The aforesaid line of code was printing the configuration of the first running trial over and over again. I was able to suppress the messages. Here is an updated version of the code you shared with some changes:
import logging
from typing import Dict, List, Optional, Union, Generator
import numpy as np
from ray.tune.execution import trial_runner
from ray.tune.experiment import Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler, AsyncHyperBandScheduler
from ray.tune.schedulers.async_hyperband import _Bracket
logger = logging.getLogger(__name__)
class ASHAv2(AsyncHyperBandScheduler):
"""Implements the Async Successive Halving with better termination.
This should provide similar theoretical performance as HyperBand but
avoid straggler issues that HyperBand faces. One implementation detail
is that when using multiple brackets, trial allocation to a bracket is done
randomly over a softmax probability.
See https://arxiv.org/abs/1810.05934
Args:
time_attr (str): A training result attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
metric (str): The training result objective value attribute. Stopping
procedures will use this attribute. If None but a mode was passed,
the `ray.tune.result.DEFAULT_METRIC` will be used per default.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
max_t (float): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
grace_period (float): Only stop trials at least this old in time.
The units are the same as the attribute named by `time_attr`.
reduction_factor (float): Used to set halving rate and amount. This
is simply a unit-less scalar.
brackets (int): Number of brackets. Each bracket has a different
halving rate, specified by the reduction factor.
"""
def __init__(
self,
time_attr: str = "training_iteration",
metric: Optional[str] = None,
mode: Optional[str] = None,
max_t: int = 100,
grace_period: int = 1,
reduction_factor: float = 4,
brackets: int = 1,
):
super().__init__(
time_attr=time_attr,
metric=metric,
mode=mode,
max_t=max_t,
grace_period=grace_period,
reduction_factor=reduction_factor,
brackets=brackets,
)
# Tracks state for new trial add
self._brackets = [
_BracketV2(grace_period, max_t, reduction_factor, s)
for s in range(brackets)
]
self._num_paused = 0
self._trial_info: Dict[str, _BracketV2] = {} # Stores trial_id -> Bracket
def on_trial_result(
self, trial_runner: "trial_runner.TrialRunner", trial: Trial, result: Dict
) -> str:
action = TrialScheduler.CONTINUE
if self._time_attr not in result or self._metric not in result:
return action
if result[self._time_attr] >= self._max_t:
action = TrialScheduler.STOP
else:
bracket = self._trial_info[trial.trial_id]
action = bracket.on_result(
trial,
result[self._time_attr],
self._metric_op * result[self._metric],
)
if action == TrialScheduler.STOP:
self._num_stopped += 1
if action == TrialScheduler.PAUSE:
self._num_paused += 1
return action
def on_trial_complete(
self, trial_runner: "trial_runner.TrialRunner", trial: Trial, result: Dict
):
if self._time_attr not in result or self._metric not in result:
return
bracket = self._trial_info[trial.trial_id]
bracket.on_result(
trial,
result[self._time_attr],
self._metric_op * result[self._metric],
complete=True,
)
del self._trial_info[trial.trial_id]
def choose_trial_to_run(self, trial_runner: "trial_runner.TrialRunner") -> Trial:
for bracket in self._brackets:
for trial in bracket.promotable_trials():
if trial and trial_runner.trial_executor.has_resources_for_trial(trial):
assert trial.status == Trial.PAUSED
logger.warning(f"Promoting trial [{trial.config}].")
bracket.unpause_trial(trial)
return trial
trial = FIFOScheduler.choose_trial_to_run(self, trial_runner)
if trial:
self._trial_info[trial.trial_id].unpause_trial(trial)
logger.info(f"Choosing trial {trial.config} to run from trialrunner.")
return trial
def debug_string(self) -> str:
out = "Using ASHAv2: num_stopped={}".format(self._num_stopped)
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
return out
class _BracketV2(_Bracket):
"""Bookkeeping system to track the cutoffs.
Rungs are created in reversed order so that we can more easily find
the correct rung corresponding to the current iteration of the result.
Example:
>>> b = _BracketV2(1, 10, 2, 3)
>>> b.on_result(trial1, 1, 2) # CONTINUE
>>> b.on_result(trial2, 1, 4) # CONTINUE
>>> b.cutoff(b._rungs[-1][1]) == 3.0 # rungs are reversed
>>> b.on_result(trial3, 1, 1) # STOP
>>> b.cutoff(b._rungs[0][1]) == 2.0
"""
def __init__(self, min_t: int, max_t: int, reduction_factor: float, s: int):
self.rf = reduction_factor
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
self._rungs = [
(min_t * self.rf ** (k + s), {}, []) for k in reversed(range(MAX_RUNGS))
]
def cutoff(self, recorded) -> Union[None, int, float, complex, np.ndarray]:
if len(recorded) < self.rf:
return None
return np.nanpercentile(list(recorded.values()), (1 - 1 / self.rf) * 100)
def top_k_ids(self, recorded) -> List[str]:
entries = list(recorded.items())
k = int(len(entries) / self.rf)
top_rung = sorted(entries, key=lambda kv: kv[1], reverse=True)[0:k]
return [tid for tid, value in top_rung]
def on_result(
self,
trial: Trial,
cur_iter: int,
cur_rew: Optional[float],
complete: bool = False,
) -> str:
action = TrialScheduler.CONTINUE
if cur_rew is None:
logger.warning(
"Reward attribute is None! Consider"
" reporting using a different field."
)
return action
for milestone, recorded, paused in self._rungs:
if cur_iter < milestone or trial.trial_id in recorded:
continue
else:
recorded[trial.trial_id] = cur_rew
top_k_trial_ids = self.top_k_ids(recorded)
if complete or trial.status != Trial.RUNNING:
break
if trial.trial_id not in top_k_trial_ids:
action = TrialScheduler.PAUSE
paused += [trial]
break
if action == TrialScheduler.PAUSE:
print(trial, cur_iter)
return action
def debug_str(self) -> str:
iters = " | ".join(
[
"Iter {:.3f}: {} [{} paused]".format(
milestone, self.cutoff(recorded), len(paused)
)
for milestone, recorded, paused in self._rungs
]
)
return "Bracket: " + iters
def promotable_trials(self) -> Generator[Trial, None, None]:
for _, recorded, paused in self._rungs:
for tid in self.top_k_ids(recorded):
paused_trials = {p.trial_id: p for p in paused}
if tid in paused_trials:
yield paused_trials[tid]
def unpause_trial(self, trial: Trial) -> None:
for _, _, paused in self._rungs:
if trial in paused:
paused.pop(paused.index(trial))
assert trial not in paused
@Yard1, I ran an experiment with the following settings:
num_samples = 350
metric = "accuracy"
mode = "max"
max_concurrent_trials = 1
search_alg = HyperOptSearch(
space=space, metric=metric, mode=mode, n_initial_points=20
)
search_alg = ConcurrencyLimiter(
search_alg, max_concurrent=max_concurrent_trials
)
scheduler = ASHAv2(
time_attr="training_iteration",
metric=metric,
mode=mode,
max_t=120,
grace_period=1,
reduction_factor=35,
)
This ran the first experiment for one epoch, which is my grace_period. Is there something I have to change in my settings?
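For reference, a quick check of the rung milestones implied by these settings, using the same formula as _BracketV2.__init__ (assuming a single bracket, i.e. s = 0); this is why the very first evaluation happens after one epoch:

import numpy as np

min_t, max_t, rf, s = 1, 120, 35, 0  # grace_period, max_t, reduction_factor, bracket index
max_rungs = int(np.log(max_t / min_t) / np.log(rf) - s + 1)  # 2
milestones = [min_t * rf ** (k + s) for k in reversed(range(max_rungs))]
print(milestones)  # [35, 1]: trials hit rung milestones at iterations 1 and 35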
Removing the
In the above experiment, 42 trials were paused at the end of their first epoch, but one suddenly CONTINUEd execution and is currently at epoch 18. Shouldn't the PENDING trial have been run instead?
Update 1:
Update 2:
The following line of code from ASHAv2 is executed, which also shows the trial was promoted:
logger.warning(f"Promoting trial [{trial.config}].")
@Yard1, can we change the
This should address the issue I explained in my previous message, but I am unsure if it will cause new problems.
At some point in the experiment, the number of paused trials doesn't add up:
=> 111 paused trials combined.
Have you tried experimenting with that?
I'm not very familiar with different aspects of Ray. If I do that, could it cause any problems in other parts of the code?
It shouldn't if you just limit it to the ASHA scheduler.
Hi all, I made some modifications to the ASHAv2 implementation from above:
The only thing still different from the original ASHA paper is the stopping criterion. In the original ASHA paper, when there are no promotable trials, new trials are spawned, and the tuning run is stopped once a certain total number of trials has been spawned (see the section "Using n as ASHA's stopping criterion" in the ASHA paper, as well as Figure 3 in this blog). I decided not to add this for now, as it was not needed for my use case, but I'm sure it could be added. Anyway, here's the code. I would happily make a PR for this, but I don't think the Ray team will want to have two different ASHA schedulers.
Edit: I forgot to mention that if you want to use this scheduler with
from ray.tune.execution import trial_runner
from ray.tune.execution.tune_controller import TuneController
from ray.tune.experiment import Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler, AsyncHyperBandScheduler
from ray.tune.schedulers.async_hyperband import _Bracket
import logging
from typing import Dict, List, Optional, Union, Generator
import numpy as np
logger = logging.getLogger(__name__)
class _Rung:
"""Simple class to store each Rung of an ASHA bracket"""
def __init__(
self,
milestone: int
):
self.milestone = milestone
self.started = []
self.paused = []
self.recorded = {}
self.killed = False
def kill_paused_trials(self, tune_controller: "TuneController"):
for trial in self.paused:
if trial.status != Trial.ERROR:
tune_controller.stop_trial(trial)
self.killed = True
def __str__(self):
s = f"""
Milestone: {self.milestone}
Started: {self.started}
Paused: {self.paused}
Recorded: {self.recorded}
Killed: {self.killed}
"""
return s
class _BracketV2(_Bracket):
"""Bookkeeping system to track the cutoffs.
Rungs are created in reversed order so that we can more easily find
the correct rung corresponding to the current iteration of the result.
Example:
>>> b = _BracketV2(1, 10, 2, 3)
>>> b.on_result(trial1, 1, 2) # CONTINUE
>>> b.on_result(trial2, 1, 4) # CONTINUE
>>> b.cutoff(b._rungs[-1][1]) == 3.0 # rungs are reversed
>>> b.on_result(trial3, 1, 1) # STOP
>>> b.cutoff(b._rungs[0][1]) == 2.0
"""
def __init__(self, min_t: int, max_t: int, reduction_factor: float, s: int, stop_last_trials: bool = True):
self.rf = reduction_factor
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
self._rungs = [_Rung(min_t * self.rf ** (k + s)) for k in reversed(range(MAX_RUNGS))]
self._stop_last_trials = stop_last_trials
def cutoff(self, rung: _Rung) -> Union[None, int, float, complex, np.ndarray]:
if len(rung.recorded) < self.rf:
return None
return np.nanpercentile(list(rung.recorded.values()), (1 - 1 / self.rf) * 100)
def top_k_ids(self, rung: _Rung) -> List[str]:
entries = list(rung.recorded.items())
k = int(len(entries) / self.rf)
top_rung = sorted(entries, key=lambda kv: kv[1], reverse=True)[0:k]
return [tid for tid, value in top_rung]
def on_result(
self,
trial: Trial,
cur_iter: int,
cur_rew: Optional[float],
complete: bool = False,
) -> str:
action = TrialScheduler.CONTINUE
if cur_rew is None:
logger.warning(
"Reward attribute is None! Consider"
" reporting using a different field."
)
return action
for rung_id, rung in enumerate(self._rungs):
if (
cur_iter >= rung.milestone
and trial.trial_id in rung.recorded
and not self._stop_last_trials
):
# If our result has been recorded for this trial already, the
# decision to continue training has already been made. Thus we can
# skip new cutoff calculation and just continue training.
# We can also break as milestones are descending.
break
if cur_iter < rung.milestone or trial.trial_id in rung.recorded:
continue
else:
rung.recorded[trial.trial_id] = cur_rew
# top_k_trial_ids = self.top_k_ids(recorded)
if complete or trial.status != Trial.RUNNING:
break
# if trial.trial_id not in top_k_trial_ids:
if rung_id > 0:
# set_trace()
action = TrialScheduler.PAUSE
# action = TrialScheduler.STOP
rung.paused += [trial]
break
if action == TrialScheduler.PAUSE:
print(trial, cur_iter)
return action
def debug_str(self) -> str:
iters = " | ".join(
[
"Iter {:.3f}: {} [{} paused]".format(
rung.milestone, self.cutoff(rung), len(rung.paused)
)
for rung in self._rungs
]
)
return "Bracket: " + iters
def _promotable_trials_per_rung(self, rung: _Rung) -> Generator[Trial, None, None]:
for tid in self.top_k_ids(rung):
paused_trials = {p.trial_id: p for p in rung.paused}
if tid in paused_trials:
yield paused_trials[tid]
def promotable_trials(self) -> Generator[Trial, None, None]:
for rung in self._rungs:
for trial in self._promotable_trials_per_rung(rung):
yield trial
def unpause_trial(self, trial: Trial) -> None:
for i, rung in enumerate(self._rungs):
if trial in rung.paused:
rung.paused.pop(rung.paused.index(trial))
if i > 0:
prev_rung = self._rungs[i-1]
prev_rung.started += [trial]
else:
raise Exception("ATTEMPTING TO UNPAUSE TRIAL ", trial, " AT HIGHEST RUNG")
assert trial not in rung.paused
class ASHAv2(AsyncHyperBandScheduler):
"""Implements the Async Successive Halving with better termination.
This should provide similar theoretical performance as HyperBand but
avoid straggler issues that HyperBand faces. One implementation detail
is that when using multiple brackets, trial allocation to a bracket is done
randomly over a softmax probability.
See https://arxiv.org/abs/1810.05934
Args:
time_attr (str): A training result attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
metric (str): The training result objective value attribute. Stopping
procedures will use this attribute. If None but a mode was passed,
the `ray.tune.result.DEFAULT_METRIC` will be used per default.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
max_t (float): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
grace_period (float): Only stop trials at least this old in time.
The units are the same as the attribute named by `time_attr`.
reduction_factor (float): Used to set halving rate and amount. This
is simply a unit-less scalar.
brackets (int): Number of brackets. Each bracket has a different
halving rate, specified by the reduction factor.
"""
def __init__(
self,
time_attr: str = "training_iteration",
metric: Optional[str] = None,
mode: Optional[str] = None,
max_t: int = 100,
grace_period: int = 1,
reduction_factor: float = 4,
brackets: int = 1,
stop_last_trials: bool = True
):
super().__init__(
time_attr=time_attr,
metric=metric,
mode=mode,
max_t=max_t,
grace_period=grace_period,
reduction_factor=reduction_factor,
brackets=brackets,
stop_last_trials=stop_last_trials
)
# Tracks state for new trial add
self._brackets = [
_BracketV2(grace_period, max_t, reduction_factor, s)
for s in range(brackets)
]
self._num_paused = 0
self._trial_info: Dict[str, _BracketV2] = {} # Stores trial_id -> Bracket
def on_trial_result(
self, tune_controller: "TuneController", trial: Trial, result: Dict
) -> str:
action = TrialScheduler.CONTINUE
if self._time_attr not in result or self._metric not in result:
return action
if result[self._time_attr] >= self._max_t and self._stop_last_trials:
action = TrialScheduler.STOP
else:
bracket = self._trial_info[trial.trial_id]
action = bracket.on_result(
trial,
result[self._time_attr],
self._metric_op * result[self._metric],
)
if action == TrialScheduler.STOP:
self._num_stopped += 1
if action == TrialScheduler.PAUSE:
self._num_paused += 1
return action
def on_trial_complete(
self, tune_controller: "TuneController", trial: Trial, result: Dict
):
if self._time_attr not in result or self._metric not in result:
return
bracket = self._trial_info[trial.trial_id]
bracket.on_result(
trial,
result[self._time_attr],
self._metric_op * result[self._metric],
complete=True,
)
del self._trial_info[trial.trial_id]
def _choose_new_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
for trial in tune_controller.get_trials():
for bracket in self._brackets:
if trial.status == Trial.PENDING and tune_controller.trial_executor.has_resources_for_trial(trial) and trial not in bracket._rungs[-1].started:
bracket._rungs[-1].started.append(trial)
print(f"CHOOSING NEW TRIAL {trial} TO RUN")
return trial
return None
def _choose_paused_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
# Note: this only chooses paused trials NOT paused because they
# reached the end of their rung, but paused for some other reason,
# e.g. from a ResourceChangingScheduler
for bracket in self._brackets:
for i, rung in enumerate(reversed(bracket._rungs)):
for trial in rung.started:
if trial.status == Trial.PAUSED and trial in rung.started and trial.trial_id not in rung.recorded:
# If the trial was started, is currently paused, and was never recorded (i.e. reached rung milestone)
# then it was paused for some other reason (e.g. a ResourceChangingScheduler), and can be resumed
print(f"UNPAUSING TRIAL {trial} FROM RUNG {rung.milestone}")
return trial
return None
def _choose_promotable_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
for bracket in self._brackets:
# Iterate rungs lowest to highest
# so that earlier trials promote faster
rungs_lowest_to_highest = list(reversed(bracket._rungs))
for rung_id, rung in enumerate(rungs_lowest_to_highest):
for trial in bracket._promotable_trials_per_rung(rung):
# print(f"HAS RESOURCES FOR PROMOTABLE TRIAL {trial}: {tune_controller.trial_executor.has_resources_for_trial(trial)}")
# print("ACTOR RESOURCES", tune_controller._actor_manager.get_live_actors_resources())
# live_resources = tune_controller._actor_manager.get_live_actors_resources()
# # used_cpu, total_cpu, used_gpu, total_gpu, _ = tune_controller._resource_updater._get_used_avail_resources(live_resources)
# used_cpu = live_resources.pop("CPU", 0)
# total_cpu = tune_controller._resource_updater._avail_resources.cpu
# used_gpu = live_resources.pop("GPU", 0)
# total_gpu = tune_controller._resource_updater._avail_resources.gpu
# avail_cpu, avail_gpu = total_cpu - used_cpu, total_gpu - used_gpu
# print(f"AVAILABLE RESOURCES: CPU {avail_cpu}, GPU {avail_gpu}")
# print("TRIAL RESOURCES", trial.placement_group_factory.required_resources)
if trial and tune_controller.trial_executor.has_resources_for_trial(trial) and trial.status == Trial.PAUSED:
assert trial.status == Trial.PAUSED
# logger.warning(f"Promoting trial {trial}.")
print(f"PROMOTING TRIAL {trial} from rung {rung.milestone}")
bracket.unpause_trial(trial)
return trial
if rung_id > 0:
prev_rung = rungs_lowest_to_highest[rung_id - 1]
if prev_rung.killed and \
not rung.killed and \
len(rung.started) == len(rung.recorded) and \
len(bracket.top_k_ids(rung)) == 0:
for trial in rung.paused:
bracket.unpause_trial(trial)
return trial
return None
def _try_terminate_finished_rungs(self, bracket: _BracketV2, tune_controller: "TuneController"):
# When there are no more trials that can possibly be
# promoted from a given rung, that rung can be "terminated", i.e.
# all paused trials can be terminated.
# Note: such early stopping is only possible because no new trials are spawned,
# unlike the original ASHA paper which spawns a new trial in the lowest rung
# when no trials can be promoted, see Alg 2 of https://arxiv.org/pdf/1810.05934.pdf.
# Consequently, this ASHA will terminate when all trials have been promoted
# as far as possible, and all trials in the last rung finish, unlike the original ASHA
# which terminates once a provided number of trials are spawned.
rungs_lowest_to_highest = list(reversed(bracket._rungs))
for rung_id, rung in enumerate(rungs_lowest_to_highest):
if rung_id == 0:
if not rung.killed and len(rung.recorded) == len(tune_controller.get_trials()) and \
not any(bracket._promotable_trials_per_rung(rung)):
# If all trials have been recorded in lowest rung and
# there are no promotable trials, we can kill all remaining trials in the rung
print(f"KILLING RUNG {rung.milestone}")
rung.kill_paused_trials(tune_controller)
else:
prev_rung = rungs_lowest_to_highest[rung_id - 1]
if prev_rung.killed and \
not rung.killed and \
len(rung.started) == len(rung.recorded) and \
len(rung.started) > 0 and \
not any(bracket._promotable_trials_per_rung(rung)):
# If the previous rung has been terminated, and
# all started trials in current rung have finished, and
# the rung has no promotable trials, can be terminated
print(f"KILLING RUNG {rung.milestone}")
rung.kill_paused_trials(tune_controller)
def choose_trial_to_run(self, tune_controller: "TuneController") -> Trial:
# First try to start any new trials
trial = self._choose_new_trial_to_run(tune_controller)
# Next, try to start trials paused for some reason
# other than finishing a rung (e.g. a ResourceChangingScheduler)
if not trial:
trial = self._choose_paused_trial_to_run(tune_controller)
# Then, look for a promotable trial
if not trial:
trial = self._choose_promotable_trial_to_run(tune_controller)
# Check if any rungs can be terminated
for bracket in self._brackets:
self._try_terminate_finished_rungs(bracket, tune_controller)
return trial
def debug_string(self) -> str:
out = "Using ASHAv2: num_stopped={}".format(self._num_stopped)
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
return out
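A minimal usage sketch for this version (assuming the code above is saved in a local module, here called asha_v2.py, and that the trainable reports an "accuracy" metric; both names are placeholders):

from ray import tune
from asha_v2 import ASHAv2  # hypothetical module containing the class above

def trainable(config):
    # Placeholder trainable; it should report an "accuracy" value once per
    # training_iteration so that ASHAv2 can compare trials at each milestone.
    ...

scheduler = ASHAv2(
    time_attr="training_iteration",
    metric="accuracy",
    mode="max",
    max_t=120,
    grace_period=1,
    reduction_factor=4,
    stop_last_trials=False,
)

tuner = tune.Tuner(
    trainable,
    tune_config=tune.TuneConfig(num_samples=64, scheduler=scheduler),
)
results = tuner.fit()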
Seemingly related: #44256
Description
Can you please add support to pause trials scheduled with the ASHA scheduler when they reach a milestone and allocate their resources to pending trials that have not reached the same milestone yet?
This feature request is related to #4401, where @mseeger has outlined a few differences between the ASHA paper and Tune's implementation.
Use case
Assume I would like to perform hyperparameter tuning on a single-node machine with eight GPUs, have set num_samples=16, and want to use TorchTrainer with num_workers=8, which means each trial will be allocated all eight GPUs.
Because the trials cannot run in parallel given the number and availability of resources, I expect all trials to be given a chance to reach the first milestone before ASHA decides which ones to stop and which ones to promote. However, in the current implementation, the first trial continues to run to the end without ever being evaluated at any milestone.
An alternative is to allocate one GPU per trial, which allows running eight trials in parallel. This utilizes all resources initially, but near the end of optimization, where most trials are pruned, some resources may be unused. One can use a ResourceChangingScheduler with TorchTrainer to allocate the unused resources to running trials, as discussed in this thread and shown in this example. However, changing the number of GPUs in a DDP experiment will change the effective batch size, which, in turn, may require non-trivial adjustments to the learning rate and other hyperparameters for best accuracy.
As a result, I would like to allocate as many resources as possible to each trial and explore the hyperparameter space in that setting so that no further tuning of the hyperparameters will be needed down the line.
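To make the batch-size concern concrete, here is a small sketch (the linear scaling rule shown here is only one common heuristic and the numbers are illustrative; nothing in DDP applies it automatically):

per_worker_batch = 64  # illustrative per-GPU batch size
base_lr = 0.1          # learning rate tuned for a single worker

def effective_batch(num_workers: int) -> int:
    # In DDP each worker processes its own mini-batch, so the effective
    # (global) batch size grows linearly with the number of workers.
    return per_worker_batch * num_workers

def linearly_scaled_lr(num_workers: int) -> float:
    # Common heuristic: scale the learning rate in proportion to the effective
    # batch size; whether this preserves accuracy is model-dependent.
    return base_lr * num_workers

print(effective_batch(1), linearly_scaled_lr(1))  # 64 0.1
print(effective_batch(8), linearly_scaled_lr(8))  # 512 0.8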