[Tune] Pause trials scheduled with the ASHA scheduler when they reach a milestone #32634

Open
MahdiNazemi opened this issue Feb 16, 2023 · 19 comments
Labels
enhancement (Request for new feature and/or capability) · P2 (Important issue, but not time-critical) · tune (Tune-related issues)

Comments

@MahdiNazemi

Description

Can you please add support for pausing trials scheduled with the ASHA scheduler when they reach a milestone, and for allocating their resources to pending trials that have not yet reached the same milestone?

This feature request is related to #4401, where @mseeger has outlined a few differences between the ASHA paper and Tune's implementation.

Use case

Assume I would like to perform hyperparameter tuning on a single-node machine with eight GPUs, have set num_samples=16, and want to use TorchTrainer with num_workers=8, which means each trial will be allocated all eight GPUs.

Because the trials cannot run in parallel given the number and availability of resources, I expect all trials to be given a chance to reach the first milestone before ASHA decides which ones to stop and which ones to promote. However, in the current implementation, the first trial continues to run to the end without ever being evaluated at any milestone.

An alternative is to allocate one GPU per trial, which allows running eight trials in parallel. This utilizes all resources initially, but near the end of optimization, where most trials are pruned, some resources may be unused. One can use a ResourceChangingScheduler with TorchTrainer to allocate the unused resources to running trials as discussed in this thread and shown in this example. However, changing the number of GPUs in a DDP experiment will change the effective batch size, which, in turn, may require non-trivial adjustments to the learning rate and other hyperparameters for best accuracy.

As a result, I would like to allocate as many resources as possible to each trial and explore the hyperparameter space in that setting so that no further tuning of the hyperparameters will be needed down the line.
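
For reference, a minimal sketch of the setup described above, assuming recent Ray 2.x APIs; the training function and the search-space values are hypothetical placeholders, and the ScalingConfig gives each trial all eight GPUs:

# Minimal sketch of the single-node, 8-GPU setup described above.
# Assumptions: a hypothetical train_loop_per_worker that reports "accuracy",
# an illustrative search space, and Ray 2.x API names.
from ray import tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

trainer = TorchTrainer(
    train_loop_per_worker,  # hypothetical training function
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),  # all 8 GPUs per trial
)

tuner = tune.Tuner(
    trainer,
    param_space={"train_loop_config": {"lr": tune.loguniform(1e-4, 1e-1)}},
    tune_config=tune.TuneConfig(
        metric="accuracy",
        mode="max",
        num_samples=16,
        scheduler=ASHAScheduler(max_t=100, grace_period=1, reduction_factor=4),
    ),
)
# results = tuner.fit()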

@MahdiNazemi MahdiNazemi added the enhancement (Request for new feature and/or capability) and triage (Needs triage: eg. priority, bug/not-bug, and owning component) labels Feb 16, 2023
@xwjiang2010
Contributor

hmmm, is it necessarily better to use more GPUs towards the end of the trials? To fully utilize these GPUs, one would have to increase the global batch size, which may impact model performance, especially towards the end. As you mentioned, the lr or other hyperparameters may need to be co-adjusted.
Is this request more about having Ray Tune "somehow" figure out how other parameters should be adjusted with a larger batch size, without user input?
cc @Yard1

@MahdiNazemi
Author

This request is about adding functionality to preempt trials that have reached a milestone in favor of those that have yet to reach the same milestone.

@xwjiang2010
Contributor

sure, but how would the remaining trials (that have not yet reached a certain milestone) leverage the additional GPUs that get allocated to them? Do those trials need to adjust other hyperparameters in light of a bigger batch size? What is your proposal here?

@xwjiang2010
Contributor

Some pseudocode would be helpful for me to understand the intended flow and the gap here.

@xwjiang2010 xwjiang2010 added the needs-repro-script (Issue needs a runnable script to be reproduced) label Feb 17, 2023
@Yard1
Member

Yard1 commented Feb 17, 2023

Hey @MahdiNazemi, please find the code below:
https://gist.github.com/Yard1/86142a00104ea47ae087f1561368ade1

This code is a bit old, but it should still work fine. It should provide the same implementation as in the paper.

We may want to integrate that into Ray Tune if there's demand

@MahdiNazemi
Author

MahdiNazemi commented Feb 17, 2023

sure, but how would the remaining trials (that have not yet reached a certain milestone) leverage the additional GPUs that get allocated to them? Do those trials need to adjust other hyperparameters in light of a bigger batch size? What is your proposal here?

@xwjiang2010, sorry if my initial explanation of the use case was not clear enough.

Here are the steps I have in mind:

  1. Define the search space assuming a trial will use all GPUs on a node. For example, if you know what the learning rate range should be when a model is trained on a single GPU and want to scale the learning rate linearly with the number of GPUs, you can adjust the learning rate range in your search space accordingly (see the learning-rate sketch below).
  2. Start the first trial while allocating all GPUs to it.
  3. Preempt the first trial when it reaches the first ASHA milestone.
  4. Repeat steps 2. and 3. for the remaining samples. At this point, all samples have used all GPUs, so the resource utilization is maximized. Additionally, we have a global view of how all experiments performed at the first milestone.
  5. Given the performance of trials at the end of step 4., prune the underperforming trials.
  6. Continue training the promoted trials. Repeat steps 2., 3., and 4.

Because each trial is allocated all resources on a node, the GPUs will be fully utilized regardless of whether we are at the beginning or near the end of our search.
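
As a concrete example of step 1, here is a hedged sketch of scaling a single-GPU learning-rate range linearly with the number of GPUs per trial; the base range and the extra hyperparameter are illustrative assumptions, not values from this report:

# Illustrative only: scale an assumed single-GPU learning-rate range linearly
# with the number of GPUs each trial will use (step 1 above).
from ray import tune

num_gpus_per_trial = 8
base_lr_low, base_lr_high = 1e-4, 1e-2  # assumed single-GPU range

search_space = {
    "lr": tune.loguniform(
        base_lr_low * num_gpus_per_trial,
        base_lr_high * num_gpus_per_trial,
    ),
    "momentum": tune.uniform(0.8, 0.99),  # hypothetical additional hyperparameter
}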

Here is how ASHA can be set to work now:

Case 1: use all GPUs per trial
In this case, we can use the search space from step 1, but the first trial will not be preempted at step 3, so it will continue for, say, 100 epochs before it is terminated. This means the trial is never evaluated at a milestone and runs to the end even if it is underperforming.

Case 2: use a non-trivial subset of GPUs per trial
I see two possible issues with this approach:

  1. If a ResourceChangingScheduler is not used, the GPUs will be underutilized near the end of the search process.
  2. If a ResourceChangingScheduler is used, the global batch size will change as trials are allocated more resources, which may require making non-trivial changes to the hyperparameters.
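
For reference, a minimal sketch of the case-2 wiring, assuming Ray 2.x APIs and illustrative base-ASHA settings; each trial would start with one GPU and the ResourceChangingScheduler hands freed GPUs to surviving trials:

# Sketch of case 2: one GPU per trial at first, with a ResourceChangingScheduler
# redistributing freed resources (base-ASHA settings are illustrative).
from ray.tune.schedulers import ASHAScheduler, ResourceChangingScheduler
from ray.tune.schedulers.resource_changing_scheduler import DistributeResources

scheduler = ResourceChangingScheduler(
    base_scheduler=ASHAScheduler(max_t=100, grace_period=1, reduction_factor=4),
    resources_allocation_function=DistributeResources(add_bundles=True),
)
# Note: as surviving trials gain workers/GPUs, the effective (global) batch size
# grows, which is the hyperparameter-coupling concern described in point 2.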

@MahdiNazemi
Author

MahdiNazemi commented Feb 17, 2023

Hey @MahdiNazemi, please find the code below: https://gist.github.com/Yard1/86142a00104ea47ae087f1561368ade1

This code is a bit old, but it should still work fine. It should provide the same implementation as in the paper.

We may want to integrate that into Ray Tune if there's demand

@Yard1, thanks for sharing the ASHAv2 implementation. I integrated the code into the schedulers directory and made the necessary changes to its __init__.py.

The following line prints the exact same output to stdout about every second:

logger.info(f"Choosing trial {trial.config} to run from trialrunner.")

Does this mean choose_trial_to_run() is being called repeatedly with the same trial.config?

@Yard1
Member

Yard1 commented Feb 17, 2023

I believe that means it wants to run the trial but cannot, e.g. because all resources are claimed. You should be able to simply remove that line of code.

@MahdiNazemi
Author

The aforesaid line of code was printing the configuration of the first running trial over and over again. I was able to suppress the messages by using a ConcurrencyLimiter.

Here is an updated version of the code you shared with some import statements changed to avoid warning messages:

import logging
from typing import Dict, List, Optional, Union, Generator

import numpy as np

from ray.tune.execution import trial_runner
from ray.tune.experiment import Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler, AsyncHyperBandScheduler
from ray.tune.schedulers.async_hyperband import _Bracket


logger = logging.getLogger(__name__)


class ASHAv2(AsyncHyperBandScheduler):
    """Implements the Async Successive Halving with better termination.
    This should provide similar theoretical performance as HyperBand but
    avoid straggler issues that HyperBand faces. One implementation detail
    is when using multiple brackets, trial allocation to bracket is done
    randomly with over a softmax probability.
    See https://arxiv.org/abs/1810.05934
    Args:
        time_attr (str): A training result attr to use for comparing time.
            Note that you can pass in something non-temporal such as
            `training_iteration` as a measure of progress, the only requirement
            is that the attribute should increase monotonically.
        metric (str): The training result objective value attribute. Stopping
            procedures will use this attribute. If None but a mode was passed,
            the `ray.tune.result.DEFAULT_METRIC` will be used per default.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.
        max_t (float): max time units per trial. Trials will be stopped after
            max_t time units (determined by time_attr) have passed.
        grace_period (float): Only stop trials at least this old in time.
            The units are the same as the attribute named by `time_attr`.
        reduction_factor (float): Used to set halving rate and amount. This
            is simply a unit-less scalar.
        brackets (int): Number of brackets. Each bracket has a different
            halving rate, specified by the reduction factor.
    """

    def __init__(
        self,
        time_attr: str = "training_iteration",
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        max_t: int = 100,
        grace_period: int = 1,
        reduction_factor: float = 4,
        brackets: int = 1,
    ):
        super().__init__(
            time_attr=time_attr,
            metric=metric,
            mode=mode,
            max_t=max_t,
            grace_period=grace_period,
            reduction_factor=reduction_factor,
            brackets=brackets,
        )

        # Tracks state for new trial add
        self._brackets = [
            _BracketV2(grace_period, max_t, reduction_factor, s)
            for s in range(brackets)
        ]
        self._num_paused = 0
        self._trial_info: Dict[str, _BracketV2] = {}  # Stores Trial -> Bracket

    def on_trial_result(
        self, trial_runner: "trial_runner.TrialRunner", trial: Trial, result: Dict
    ) -> str:
        action = TrialScheduler.CONTINUE
        if self._time_attr not in result or self._metric not in result:
            return action
        if result[self._time_attr] >= self._max_t:
            action = TrialScheduler.STOP
        else:
            bracket = self._trial_info[trial.trial_id]
            action = bracket.on_result(
                trial,
                result[self._time_attr],
                self._metric_op * result[self._metric],
            )
        if action == TrialScheduler.STOP:
            self._num_stopped += 1
        if action == TrialScheduler.PAUSE:
            self._num_paused += 1
        return action

    def on_trial_complete(
        self, trial_runner: "trial_runner.TrialRunner", trial: Trial, result: Dict
    ):
        if self._time_attr not in result or self._metric not in result:
            return
        bracket = self._trial_info[trial.trial_id]
        bracket.on_result(
            trial,
            result[self._time_attr],
            self._metric_op * result[self._metric],
            complete=True,
        )
        del self._trial_info[trial.trial_id]

    def choose_trial_to_run(self, trial_runner: "trial_runner.TrialRunner") -> Trial:
        for bracket in self._brackets:
            for trial in bracket.promotable_trials():
                if trial and trial_runner.trial_executor.has_resources_for_trial(trial):
                    assert trial.status == Trial.PAUSED
                    logger.warning(f"Promoting trial [{trial.config}].")
                    bracket.unpause_trial(trial)
                    return trial
        trial = FIFOScheduler.choose_trial_to_run(self, trial_runner)
        if trial:
            self._trial_info[trial.trial_id].unpause_trial(trial)
            logger.info(f"Choosing trial {trial.config} to run from trialrunner.")
        return trial

    def debug_string(self) -> str:
        out = "Using ASHAv2: num_stopped={}".format(self._num_stopped)
        out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
        return out


class _BracketV2(_Bracket):
    """Bookkeeping system to track the cutoffs.
    Rungs are created in reversed order so that we can more easily find
    the correct rung corresponding to the current iteration of the result.
    Example:
        >>> b = _BracketV2(1, 10, 2, 3)
        >>> b.on_result(trial1, 1, 2)  # CONTINUE
        >>> b.on_result(trial2, 1, 4)  # CONTINUE
        >>> b.cutoff(b._rungs[-1][1]) == 3.0  # rungs are reversed
        >>> b.on_result(trial3, 1, 1)  # STOP
        >>> b.cutoff(b._rungs[0][1]) == 2.0
    """

    def __init__(self, min_t: int, max_t: int, reduction_factor: float, s: int):
        self.rf = reduction_factor
        MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
        self._rungs = [
            (min_t * self.rf ** (k + s), {}, []) for k in reversed(range(MAX_RUNGS))
        ]

    def cutoff(self, recorded) -> Union[None, int, float, complex, np.ndarray]:
        if len(recorded) < self.rf:
            return None
        return np.nanpercentile(list(recorded.values()), (1 - 1 / self.rf) * 100)

    def top_k_ids(self, recorded) -> List[str]:
        entries = list(recorded.items())
        k = int(len(entries) / self.rf)
        top_rung = sorted(entries, key=lambda kv: kv[1], reverse=True)[0:k]
        return [tid for tid, value in top_rung]

    def on_result(
        self,
        trial: Trial,
        cur_iter: int,
        cur_rew: Optional[float],
        complete: bool = False,
    ) -> str:
        action = TrialScheduler.CONTINUE
        if cur_rew is None:
            logger.warning(
                "Reward attribute is None! Consider"
                " reporting using a different field."
            )
            return action
        for milestone, recorded, paused in self._rungs:
            if cur_iter < milestone or trial.trial_id in recorded:
                continue
            else:
                recorded[trial.trial_id] = cur_rew
                top_k_trial_ids = self.top_k_ids(recorded)
                if complete or trial.status != Trial.RUNNING:
                    break
                if trial.trial_id not in top_k_trial_ids:
                    action = TrialScheduler.PAUSE
                    paused += [trial]
                break
        if action == TrialScheduler.PAUSE:
            print(trial, cur_iter)
        return action

    def debug_str(self) -> str:
        iters = " | ".join(
            [
                "Iter {:.3f}: {} [{} paused]".format(
                    milestone, self.cutoff(recorded), len(paused)
                )
                for milestone, recorded, paused in self._rungs
            ]
        )
        return "Bracket: " + iters

    def promotable_trials(self) -> Generator[Trial, None, None]:
        for _, recorded, paused in self._rungs:
            for tid in self.top_k_ids(recorded):
                paused_trials = {p.trial_id: p for p in paused}
                if tid in paused_trials:
                    yield paused_trials[tid]

    def unpause_trial(self, trial: Trial) -> None:
        for _, _, paused in self._rungs:
            if trial in paused:
                paused.pop(paused.index(trial))
            assert trial not in paused

@MahdiNazemi
Author

MahdiNazemi commented Feb 18, 2023

@Yard1, I ran an experiment with the following settings:

num_samples = 350

metric = "accuracy"
mode = "max"

max_concurrent_trials = 1


search_alg = HyperOptSearch(
    space=space, metric=metric, mode=mode, n_initial_points=20
)
search_alg = ConcurrencyLimiter(
    search_alg, max_concurrent=max_concurrent_trials
)

scheduler = ASHAv2(
    time_attr="training_iteration",
    metric=metric,
    mode=mode,
    max_t=120,
    grace_period=1,
    reduction_factor=35,
)

This ran the first experiment for one epoch, which is my grace_period, but immediately continued the same experiment from a saved checkpoint instead of initiating a new trial.

Is there something I have to change in ASHAv2, my search_alg, or scheduler?
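
For reference, a quick check of the rung milestones these settings imply, computed the same way as in _BracketV2.__init__ above:

# Rung milestones implied by grace_period=1, max_t=120, reduction_factor=35
# (mirrors the MAX_RUNGS computation in _BracketV2.__init__).
import numpy as np

min_t, max_t, rf, s = 1, 120, 35, 0
max_rungs = int(np.log(max_t / min_t) / np.log(rf) - s + 1)  # -> 2 rungs
milestones = [min_t * rf ** (k + s) for k in reversed(range(max_rungs))]
print(milestones)  # [35, 1] -> milestones at epoch 35 and epoch 1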

@MahdiNazemi
Author

Removing the ConcurrencyLimiter resolves the issue.

@MahdiNazemi
Author

MahdiNazemi commented Feb 22, 2023

@Yard1, I ran an experiment with the following settings:

num_samples = 350

metric = "accuracy"
mode = "max"

search_alg = HyperOptSearch(
    space=space, metric=metric, mode=mode, n_initial_points=20
)

scheduler = ASHAv2(
    time_attr="training_iteration",
    metric=metric,
    mode=mode,
    max_t=120,
    grace_period=1,
    reduction_factor=35,
)

In the above experiment, 42 trials were paused at the end of their first epoch, but one suddenly CONTINUEd execution and is currently at epoch 18. Shouldn't the PENDING trial have been run instead?

Update 1:
The trial was PAUSED at the end of epoch 35.

Bracket: Iter 35.000: None [1 paused] | Iter 1.000: 7.7419998168945305 [41 paused]

Update 2:
Another trial was promoted at the following stage of the experiment:

Bracket: Iter 35.000: None [1 paused] | Iter 1.000: 7.616686030796596 [102 paused]

The following line of code from ASHAv2 is executed, which also shows the trial was promoted:

logger.warning(f"Promoting trial [{trial.config}].")

@MahdiNazemi MahdiNazemi reopened this Feb 22, 2023
@MahdiNazemi
Author

MahdiNazemi commented Feb 22, 2023

@Yard1, can we change the choose_trial_to_run method to only return a promotable trial if there are no pending trials in the FIFO scheduler?

This should address the issue I explained in my previous message, but I am unsure if it will cause new problems.
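
A minimal sketch of that proposed change, written as a drop-in replacement for the choose_trial_to_run method of the ASHAv2 snippet above (it reuses that snippet's imports and helpers); promotions happen only when the FIFO scheduler has nothing runnable:

# Hedged sketch of the proposal: prefer pending (never-run) trials over promotions.
def choose_trial_to_run(self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
    # Let the FIFO scheduler hand out trials first, so every sample gets a
    # chance to reach the first milestone before any paused trial is promoted.
    trial = FIFOScheduler.choose_trial_to_run(self, trial_runner)
    if trial:
        self._trial_info[trial.trial_id].unpause_trial(trial)
        return trial
    # Only when nothing else is runnable, promote a paused trial past its milestone.
    for bracket in self._brackets:
        for candidate in bracket.promotable_trials():
            if candidate and trial_runner.trial_executor.has_resources_for_trial(candidate):
                assert candidate.status == Trial.PAUSED
                bracket.unpause_trial(candidate)
                return candidate
    return None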

@MahdiNazemi
Author

At some point in the experiment, the number of paused trials doesn't add up:

Bracket: Iter 35.000: None [1 paused] | Iter 1.000: 7.5924002647399895 [110 paused]

=> 111 paused trials combined.

Number of trials: 115/350 (113 PAUSED, 1 PENDING, 1 RUNNING)

@Yard1
Member

Yard1 commented Feb 22, 2023

@Yard1, can we change the choose_trial_to_run method to only return a promotable trial if there are no pending trials in the FIFO scheduler?

This should address the issue I explained in my previous message, but I am unsure if it will cause new problems.

Have you tried experimenting with that?

@MahdiNazemi
Author

I'm not very familiar with different aspects of Ray. If I do that, could it cause any problems in other parts of the code?

@Yard1
Member

Yard1 commented Feb 23, 2023

It shouldn't if you just limit it to the ASHA scheduler

@justinvyu justinvyu added the tune (Tune-related issues) and P2 (Important issue, but not time-critical) labels and removed the needs-repro-script (Issue needs a runnable script to be reproduced) and triage (Needs triage: eg. priority, bug/not-bug, and owning component) labels Feb 24, 2023
@mvanness354

mvanness354 commented Aug 22, 2023

Hi all, I made some modifications to the ASHAv2 implementation from above:

  • Trials are promoted from lower rungs before higher rungs. By extension, trials that have not yet started in any rung are started before any actual promotions (as suggested by @MahdiNazemi).
  • Rungs are "killed", i.e. all paused trials are terminated, if the remaining paused trials have no chance of being promoted. The logic for when rungs can be killed is in comments in the updated code. This makes it so that when all trials in the last rung are terminated (because they have reached max_t), the whole tune job is terminated, assuming there are no more possible promotions. This is also nice for use with a ResourceChangingScheduler, so the scheduler knows correctly when to allocate new resources.
  • The stop_last_trials parameter from the base ASHA is added back in to allow trials to run past max_t if desired.
  • Some additional logic was added so that pairing with ResourceChangingScheduler works correctly.

The only thing still different from the original ASHA paper is the stopping criterion. In the original ASHA paper, when there are no promotable trials, new trials are spawned, and the tuning run is stopped when a certain number of total trials are spawned (see the section "Using n as ASHA’s stopping criterion" from the ASHA paper, as well as Figure 3 in this blog). I decided not to add this for now as it was not needed for my use case, but I'm sure it could be added.

Anyway, here's the code. I would happily make a PR for this, but I don't think the Ray team will want to have two different ASHA schedulers.

Edit: I forgot to mention that if you want to use this scheduler with XGBoostTrainer, there is a bug that currently prevents checkpointing from working, and thus the paused trials are actually restarted from iteration 0 when resumed. Luckily, there is a workaround: comment out this if statement. I will link an issue for this when I get around to creating one.

from ray.tune.execution import trial_runner
from ray.tune.execution.tune_controller import TuneController
from ray.tune.experiment import Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler, AsyncHyperBandScheduler
from ray.tune.schedulers.async_hyperband import _Bracket
import logging
from typing import Dict, List, Optional, Union, Generator
import numpy as np


logger = logging.getLogger(__name__)

class _Rung:
    """Simple class to store each Rung of an ASHA bracket"""

    def __init__(
        self,
        milestone: int
    ):
        self.milestone = milestone
        self.started = []
        self.paused = []
        self.recorded = {}
        self.killed = False

    def kill_paused_trials(self, tune_controller: "TuneController"):
        for trial in self.paused:
            if trial.status != Trial.ERROR:
                tune_controller.stop_trial(trial)
        self.killed = True

    def __str__(self):
        s = f"""
        Milestone: {self.milestone}
        Started: {self.started}
        Paused: {self.paused}
        Recorded: {self.recorded}
        Killed: {self.killed}
        """
        return s


class _BracketV2(_Bracket):
    """Bookkeeping system to track the cutoffs.
    Rungs are created in reversed order so that we can more easily find
    the correct rung corresponding to the current iteration of the result.
    Example:
        >>> b = _BracketV2(1, 10, 2, 3)
        >>> b.on_result(trial1, 1, 2)  # CONTINUE
        >>> b.on_result(trial2, 1, 4)  # CONTINUE
        >>> b.cutoff(b._rungs[-1][1]) == 3.0  # rungs are reversed
        >>> b.on_result(trial3, 1, 1)  # STOP
        >>> b.cutoff(b._rungs[0][1]) == 2.0
    """

    def __init__(self, min_t: int, max_t: int, reduction_factor: float, s: int, stop_last_trials: bool = True):
        self.rf = reduction_factor
        MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
        self._rungs = [_Rung(min_t * self.rf ** (k + s)) for k in reversed(range(MAX_RUNGS))]
        self._stop_last_trials = stop_last_trials

    def cutoff(self, rung: _Rung) -> Union[None, int, float, complex, np.ndarray]:
        if len(rung.recorded) < self.rf:
            return None
        return np.nanpercentile(list(rung.recorded.values()), (1 - 1 / self.rf) * 100)

    def top_k_ids(self, rung: _Rung) -> List[str]:
        entries = list(rung.recorded.items())
        k = int(len(entries) / self.rf)
        top_rung = sorted(entries, key=lambda kv: kv[1], reverse=True)[0:k]
        return [tid for tid, value in top_rung]

    def on_result(
        self,
        trial: Trial,
        cur_iter: int,
        cur_rew: Optional[float],
        complete: bool = False,
    ) -> str:

        action = TrialScheduler.CONTINUE
        if cur_rew is None:
            logger.warning(
                "Reward attribute is None! Consider"
                " reporting using a different field."
            )
            return action
        for rung_id, rung in enumerate(self._rungs):

            if (
                cur_iter >= rung.milestone
                and trial.trial_id in rung.recorded
                and not self._stop_last_trials
            ):
                # If our result has been recorded for this trial already, the
                # decision to continue training has already been made. Thus we can
                # skip new cutoff calculation and just continue training.
                # We can also break as milestones are descending.
                break
            if cur_iter < rung.milestone or trial.trial_id in rung.recorded:
                continue
            else:
                rung.recorded[trial.trial_id] = cur_rew
                if complete or trial.status != Trial.RUNNING:
                    break
                if rung_id > 0:
                    action = TrialScheduler.PAUSE
                    rung.paused += [trial]
                break
        if action == TrialScheduler.PAUSE:
            print(trial, cur_iter)
        return action

    def debug_str(self) -> str:
        iters = " | ".join(
            [
                "Iter {:.3f}: {} [{} paused]".format(
                    rung.milestone, self.cutoff(rung), len(rung.paused)
                )
                for rung in self._rungs
            ]
        )
        return "Bracket: " + iters

    def _promotable_trials_per_rung(self, rung: _Rung) -> Generator[Trial, None, None]:
        for tid in self.top_k_ids(rung):
            paused_trials = {p.trial_id: p for p in rung.paused}
            if tid in paused_trials:
                yield paused_trials[tid]

    def promotable_trials(self) -> Generator[Trial, None, None]:
        for rung in self._rungs:
            for trial in self._promotable_trials_per_rung(rung):
                yield trial

    def unpause_trial(self, trial: Trial) -> None:
        for i, rung in enumerate(self._rungs):
            if trial in rung.paused:
                rung.paused.remove(trial)
                if i > 0:
                    prev_rung = self._rungs[i - 1]
                    prev_rung.started += [trial]
                else:
                    raise Exception(f"Attempting to unpause trial {trial} at highest rung")

            assert trial not in rung.paused


class ASHAv2(AsyncHyperBandScheduler):
    """Implements the Async Successive Halving with better termination.
    This should provide similar theoretical performance as HyperBand but
    avoid straggler issues that HyperBand faces. One implementation detail
    is when using multiple brackets, trial allocation to bracket is done
    randomly with over a softmax probability.
    See https://arxiv.org/abs/1810.05934
    Args:
        time_attr (str): A training result attr to use for comparing time.
            Note that you can pass in something non-temporal such as
            `training_iteration` as a measure of progress, the only requirement
            is that the attribute should increase monotonically.
        metric (str): The training result objective value attribute. Stopping
            procedures will use this attribute. If None but a mode was passed,
            the `ray.tune.result.DEFAULT_METRIC` will be used per default.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.
        max_t (float): max time units per trial. Trials will be stopped after
            max_t time units (determined by time_attr) have passed.
        grace_period (float): Only stop trials at least this old in time.
            The units are the same as the attribute named by `time_attr`.
        reduction_factor (float): Used to set halving rate and amount. This
            is simply a unit-less scalar.
        brackets (int): Number of brackets. Each bracket has a different
            halving rate, specified by the reduction factor.
    """

    def __init__(
        self,
        time_attr: str = "training_iteration",
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        max_t: int = 100,
        grace_period: int = 1,
        reduction_factor: float = 4,
        brackets: int = 1,
        stop_last_trials: bool = True
    ):
        super().__init__(
            time_attr=time_attr,
            metric=metric,
            mode=mode,
            max_t=max_t,
            grace_period=grace_period,
            reduction_factor=reduction_factor,
            brackets=brackets,
            stop_last_trials=stop_last_trials
        )

        # Tracks state for new trial add
        self._brackets = [
            _BracketV2(grace_period, max_t, reduction_factor, s, stop_last_trials)
            for s in range(brackets)
        ]
        self._num_paused = 0
        self._trial_info: Dict[str, _BracketV2] = {}  # Stores Trial -> Bracket

    def on_trial_result(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ) -> str:
        action = TrialScheduler.CONTINUE
        if self._time_attr not in result or self._metric not in result:
            return action
        if result[self._time_attr] >= self._max_t and self._stop_last_trials:
            action = TrialScheduler.STOP
        else:
            bracket = self._trial_info[trial.trial_id]
            action = bracket.on_result(
                trial,
                result[self._time_attr],
                self._metric_op * result[self._metric],
            )
        if action == TrialScheduler.STOP:
            self._num_stopped += 1
        if action == TrialScheduler.PAUSE:
            self._num_paused += 1
        return action

    def on_trial_complete(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ):
        if self._time_attr not in result or self._metric not in result:
            return
        bracket = self._trial_info[trial.trial_id]
        bracket.on_result(
            trial,
            result[self._time_attr],
            self._metric_op * result[self._metric],
            complete=True,
        )
        del self._trial_info[trial.trial_id]

    def _choose_new_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
        for trial in tune_controller.get_trials():
            for bracket in self._brackets:
                if trial.status == Trial.PENDING and tune_controller.trial_executor.has_resources_for_trial(trial) and trial not in bracket._rungs[-1].started:
                    bracket._rungs[-1].started.append(trial)
                    print(f"CHOOSING NEW TRIAL {trial} TO RUN")
                    return trial
        return None

    def _choose_paused_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
        # Note: this only chooses paused trials NOT paused because they 
        # reached the end of their rung, but paused for some other reason,
        # e.g. from a ResourceChangingScheduler
        for bracket in self._brackets:
            for i, rung in enumerate(reversed(bracket._rungs)):
                for trial in rung.started:
                    if trial.status == Trial.PAUSED and trial in rung.started and trial.trial_id not in rung.recorded:
                        # If the trial was started, is currently paused, and was never recorded (i.e. reached rung milestone)
                        # then it was paused for some other reason (e.g. a ResourceChangingScheduler), and can be resumed
                        print(f"UNPAUSING TRIAL {trial} FROM RUNG {rung.milestone}")
                        return trial

        return None

    def _choose_promotable_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
        for bracket in self._brackets:
            # Iterate rungs lowest to highest
            # so that earlier trials promote faster
            rungs_lowest_to_highest = list(reversed(bracket._rungs))
            for rung_id, rung in enumerate(rungs_lowest_to_highest):
                for trial in bracket._promotable_trials_per_rung(rung):
                    # print(f"HAS RESOURCES FOR PROMOTABLE TRIAL {trial}: {tune_controller.trial_executor.has_resources_for_trial(trial)}")
                    # print("ACTOR RESOURCES", tune_controller._actor_manager.get_live_actors_resources())
                    # live_resources = tune_controller._actor_manager.get_live_actors_resources()
                    # # used_cpu, total_cpu, used_gpu, total_gpu, _ = tune_controller._resource_updater._get_used_avail_resources(live_resources)

                    # used_cpu = live_resources.pop("CPU", 0)
                    # total_cpu = tune_controller._resource_updater._avail_resources.cpu
                    # used_gpu = live_resources.pop("GPU", 0)
                    # total_gpu = tune_controller._resource_updater._avail_resources.gpu
                    
                    # avail_cpu, avail_gpu = total_cpu - used_cpu, total_gpu - used_gpu
                    # print(f"AVAILABLE RESOURCES: CPU {avail_cpu}, GPU {avail_gpu}")
                    # print("TRIAL RESOURCES", trial.placement_group_factory.required_resources)
                    if trial and tune_controller.trial_executor.has_resources_for_trial(trial) and trial.status == Trial.PAUSED:
                        assert trial.status == Trial.PAUSED
                        # logger.warning(f"Promoting trial {trial}.")
                        print(f"PROMOTING TRIAL {trial} from rung {rung.milestone}")
                        bracket.unpause_trial(trial)
                        return trial

                if rung_id > 0:
                    prev_rung = rungs_lowest_to_highest[rung_id - 1]
                    if prev_rung.killed and \
                    not rung.killed and \
                    len(rung.started) == len(rung.recorded) and \
                    len(bracket.top_k_ids(rung)) == 0:
                        for trial in rung.paused:
                            bracket.unpause_trial(trial)
                            return trial

        return None

    def _try_terminate_finished_rungs(self, bracket: _BracketV2, tune_controller: "TuneController"):
        # When there are no more trials that can possibly be
        # promoted from a given rung, that rung can be "terminated", i.e.
        # all paused trials can be terminated.
        
        # Note: such early stopping is only possible because no new trials are spawned, 
        # unlike the original ASHA paper which spawns a new trial in the lowest rung
        # when no trials can be promoted, see Alg 2 of https://arxiv.org/pdf/1810.05934.pdf.
        # Consequently, this ASHA will terminate when all trials have been promoted
        # as far as possible, and all trials in the last rung finish, unlike the original ASHA
        # which terminates once a provided number of trials are spawned.

        
        rungs_lowest_to_highest = list(reversed(bracket._rungs))
        for rung_id, rung in enumerate(rungs_lowest_to_highest):
            if rung_id == 0:
                if not rung.killed and len(rung.recorded) == len(tune_controller.get_trials()) and \
                not any(bracket._promotable_trials_per_rung(rung)):
                    # If all trials have been recorded in lowest rung and
                    # there are no promotable trials, we can kill all remaining trials in the rung
                    print(f"KILLING RUNG {rung.milestone}")
                    rung.kill_paused_trials(tune_controller)

            else:
                prev_rung = rungs_lowest_to_highest[rung_id - 1]
                if prev_rung.killed and \
                not rung.killed and \
                len(rung.started) == len(rung.recorded) and \
                len(rung.started) > 0 and \
                not any(bracket._promotable_trials_per_rung(rung)):
                    # If the previous rung has been terminated, and
                    # all started trials in current rung have finished, and
                    # the rung has no promotable trials, can be terminated
                    print(f"KILLING RUNG {rung.milestone}")
                    rung.kill_paused_trials(tune_controller)

    def choose_trial_to_run(self, tune_controller: "TuneController") -> Trial:

        # First try to start any new trials
        trial = self._choose_new_trial_to_run(tune_controller)

        # Next, try to start trials paused for some reason
        # other than finishing a rung (e.g. a ResourceChangingScheduler)
        if not trial:
            trial = self._choose_paused_trial_to_run(tune_controller)

        # Then, look for a promotable trial
        if not trial:
            trial = self._choose_promotable_trial_to_run(tune_controller)

        # Check if any rungs can be terminated
        for bracket in self._brackets:
            self._try_terminate_finished_rungs(bracket, tune_controller)

        return trial

    def debug_string(self) -> str:
        out = "Using ASHAv2: num_stopped={}".format(self._num_stopped)
        out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
        return out
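
For completeness, a hedged sketch of wiring this version up, showing the restored stop_last_trials flag and the ResourceChangingScheduler pairing mentioned in the list above; train_fn and the search space are hypothetical placeholders, and the API names assume a recent Ray 2.x release:

# Hedged usage sketch for the ASHAv2 defined above (not an official Ray API).
from ray import tune
from ray.tune.schedulers import ResourceChangingScheduler
from ray.tune.schedulers.resource_changing_scheduler import DistributeResources

asha_v2 = ASHAv2(
    metric="accuracy",
    mode="max",
    max_t=100,
    grace_period=1,
    reduction_factor=4,
    stop_last_trials=False,  # allow trials in the top rung to run past max_t
)

scheduler = ResourceChangingScheduler(
    base_scheduler=asha_v2,
    resources_allocation_function=DistributeResources(add_bundles=True),
)

tuner = tune.Tuner(
    train_fn,  # hypothetical trainable reporting "accuracy" each training_iteration
    param_space={"lr": tune.loguniform(1e-4, 1e-1)},  # hypothetical search space
    tune_config=tune.TuneConfig(num_samples=16, scheduler=scheduler),
)
# results = tuner.fit()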

@fzyzcjy
Contributor

fzyzcjy commented Mar 23, 2024

Seemingly related: #44256
