From 180c3b99dc34f8764786a041b27f61b2eff9bc3d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:41:35 +0100 Subject: [PATCH 01/81] [tune/train] Consolidate checkpoint manager 1: Common checkpoint manager class --- .../ray/util/ml_utils/checkpoint_manager.py | 378 ++++++++++++++++++ 1 file changed, 378 insertions(+) create mode 100644 python/ray/util/ml_utils/checkpoint_manager.py diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py new file mode 100644 index 000000000000..f97068b58377 --- /dev/null +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -0,0 +1,378 @@ +import gc +import heapq +import logging +import numbers +import os +import shutil + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Dict, Union, Callable, Tuple, List, Any + +import ray +from ray.tune.result import NODE_IP +from ray.util import PublicAPI +from ray.util.annotations import DeveloperAPI +from ray.util.ml_utils.util import is_nan + +MAX = "max" +MIN = "min" + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class TrackedCheckpoint: + """Checkpoint tracked by a checkpoint manager. + + This class is used to track checkpoints generated by trainables and trainers in + order to add metadata (e.g. the result, or the node where it has been created) + and for bookkeeping purposes. + + Args: + dir_or_data: Checkpoint directory, checkpoint data, or a future to either. + storage_mode: Either MEMORY or PERSISTENT. + checkpoint_id: Checkpoint number. Usually this should be monotonically + increasing for each tracked checkpoint. + result: Observed metrics for this checkpoint. This is used to determine + the value of the ``checkpoint_score_attr``. + node_ip: IP of the node where the checkpoint was generated. Defaults + to the current node. + """ + + MEMORY = "memory" + PERSISTENT = "persistent" + + def __init__( + self, + dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], + storage_mode: str, + checkpoint_id: Optional[int] = None, + result: Optional[Dict] = None, + node_ip: Optional[str] = None, + ): + self.dir_or_data = dir_or_data + self.id = checkpoint_id + self.storage_mode = storage_mode + + # Todo: What to do if result is a subset of dir_or_data (dict) + self.result = result or {} + self.node_ip = node_ip or self.result.get(NODE_IP, None) + + def commit(self, path: Optional[Path] = None) -> None: + """Commit checkpoint to disk, if needed. + + Args: + path: Path to commit checkpoint to. + """ + pass + + def delete( + self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None + ) -> None: + """Delete checkpoint from disk, if needed. + + Args: + delete_fn: Function to be called with the tracked checkpoint as an + argument. Defaults to removing the local directory/file. 
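As an illustrative sketch of this hook (the helper name and archive path below are made up, not taken from the patch): a custom ``delete_fn`` receives the tracked checkpoint object itself, not a bare path, so it can inspect metadata before removing data.

    import shutil
    from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint

    def archive_then_delete(checkpoint: TrackedCheckpoint) -> None:
        # Move the persisted checkpoint directory aside instead of removing it.
        shutil.move(str(checkpoint.dir_or_data), "/tmp/archived_checkpoints")

    # Later: tracked_checkpoint.delete(delete_fn=archive_then_delete)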
+ """ + delete_fn = delete_fn or _default_delete_fn + try: + delete_fn(self) + except Exception as e: + logger.warning(f"Checkpoint deletion failed: {e}") + + def __repr__(self): + if self.storage_mode == TrackedCheckpoint.MEMORY: + return f"" + + return ( + f"" + ) + + +def _default_delete_fn(checkpoint: TrackedCheckpoint): + if checkpoint.storage_mode != TrackedCheckpoint.PERSISTENT: + return + + if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): + if os.path.isfile(checkpoint.dir_or_data): + os.remove(checkpoint.dir_or_data) + return + elif os.path.isdir(checkpoint.dir_or_data): + shutil.rmtree(checkpoint.dir_or_data) + return + raise RuntimeError( + f"Could not delete checkpoint {checkpoint} from disk as it is " + f"neither file not directory. Path: {checkpoint.dir_or_data}." + ) + + +class _HeapCheckpointWrapper: + def __init__(self, priority: Any, tracked_checkpoint: TrackedCheckpoint): + self.priority = priority + self.tracked_checkpoint = tracked_checkpoint + + def __lt__(self, other): + return self.priority < other.priority + + def __repr__(self): + return f"_HeapCheckpoint({repr(self.tracked_checkpoint)})" + + +@PublicAPI(stability="beta") +@dataclass +class CheckpointStrategy: + """Configurable parameters for defining the checkpointing strategy. + + Default behavior is to persist all checkpoints to disk. If + ``num_to_keep`` is set, the default retention policy is to keep the + checkpoints with maximum timestamp, i.e. the most recent checkpoints. + + Args: + num_to_keep (Optional[int]): The number of checkpoints to keep + on disk for this run. If a checkpoint is persisted to disk after + there are already this many checkpoints, then an existing + checkpoint will be deleted. If this is ``None`` then checkpoints + will not be deleted. If this is ``0`` then no checkpoints will be + persisted to disk. + checkpoint_score_attribute (str): The attribute that will be used to + score checkpoints to determine which checkpoints should be kept + on disk when there are greater than ``num_to_keep`` checkpoints. + This attribute must be a key from the checkpoint + dictionary which has a numerical value. Per default, the last + checkpoints will be kept. + checkpoint_score_order (str). Either "max" or "min". + If "max", then checkpoints with highest values of + ``checkpoint_score_attribute`` will be kept. + If "min", then checkpoints with lowest values of + ``checkpoint_score_attribute`` will be kept. + """ + + num_to_keep: Optional[int] = None + checkpoint_score_attribute: Optional[str] = None + checkpoint_score_order: str = MAX + + def __post_init__(self): + if self.num_to_keep is not None and self.num_to_keep < 0: + raise ValueError( + f"Received invalid num_to_keep: " + f"{self.num_to_keep}. " + f"Must be None or non-negative integer." + ) + if self.checkpoint_score_order not in (MAX, MIN): + raise ValueError( + f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' + ) + + +class CheckpointManager: + """Common checkpoint management and bookkeeping class for Ray Train and Tune. + + This class acts as the common core for checkpoint bookkeeping in Ray ML libraries. + On a high level, this manager keeps a reference to all stored checkpoints + (both in-memory and on-disk checkpoints). For on-disk checkpoints, it + keeps a configured number of checkpoints according to specified metrics. + + The manager supports lazy data writing by utilizing the + ``TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint + should be persisted to disk. 
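To illustrate how the pieces defined above fit together, here is a rough usage sketch. It is illustrative only: ``MyCheckpointManager`` is a hypothetical subclass (``register_checkpoint`` on the base class is abstract), and the metric name and data are made up.

    from ray.util.ml_utils.checkpoint_manager import (
        CheckpointManager,
        CheckpointStrategy,
        TrackedCheckpoint,
    )

    class MyCheckpointManager(CheckpointManager):
        def register_checkpoint(self, checkpoint: TrackedCheckpoint) -> None:
            # Hand the checkpoint to the shared bookkeeping, which commits,
            # keeps, or deletes it according to the configured strategy.
            self._decide_what_to_do_with_checkpoint(checkpoint)
            self._latest_checkpoint_id += 1

    strategy = CheckpointStrategy(
        num_to_keep=2,
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    )
    manager = MyCheckpointManager(checkpoint_strategy=strategy)
    manager.register_checkpoint(
        TrackedCheckpoint(
            dir_or_data={"weights": [0.1, 0.2]},
            storage_mode=TrackedCheckpoint.PERSISTENT,
            checkpoint_id=0,
            result={"loss": 0.25},
        )
    )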
+ """ + + def __init__( + self, + checkpoint_strategy: CheckpointStrategy, + latest_checkpoint_id: int = 0, + delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, + ): + self._checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + + # Incremental unique checkpoint ID of this run. + self._latest_checkpoint_id = latest_checkpoint_id + + # Used for keeping top K checkpoints. + self._top_persisted_checkpoints: List[_HeapCheckpointWrapper] = [] + + # Best checkpoint altogether. + # Used for exposing best_checkpoint_path. + self._best_persisted_checkpoint: Optional[TrackedCheckpoint] = None + self._latest_persisted_checkpoint: Optional[TrackedCheckpoint] = None + self._latest_memory_checkpoint: Optional[TrackedCheckpoint] = None + + # Checkpoints that are not immediately removed + self._checkpoints_to_clean_up = set() + self._delete_fn = delete_fn + + def set_delete_fn(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]]): + """Update the function called to delete persisted checkpoints. + + Args: + delete_fn: Function that takes a tracked checkpoint as an argument and + deletes it from disk. + """ + self._delete_fn = delete_fn + + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + """Register new checkpoint and add to bookkeeping. + + This method will register a new checkpoint and add it to the internal + bookkeeping logic. This means the checkpoint manager will decide if + this checkpoint should be kept, and if older or worse performing + checkpoints should be deleted. + + Subclasses have to implement this method. + + Args: + checkpoint: Tracked checkpoint object to add to bookkeeping. + """ + raise NotImplementedError + + def _replace_latest_memory_checkpoint(self, memory_checkpoint: TrackedCheckpoint): + assert memory_checkpoint.storage_mode == TrackedCheckpoint.MEMORY + self._latest_memory_checkpoint = memory_checkpoint + # Avoid memory leaks on k8s pods + gc.collect() + + def _replace_latest_persisted_checkpoint( + self, persisted_checkpoint: TrackedCheckpoint + ): + second_to_latest_persisted_checkpoint = self._latest_persisted_checkpoint + self._latest_persisted_checkpoint = persisted_checkpoint + + if self._checkpoint_strategy.num_to_keep == 0: + self._maybe_delete_persisted_checkpoint( + second_to_latest_persisted_checkpoint + ) + + def _maybe_replace_best_persisted_checkpoint( + self, persisted_checkpoint: TrackedCheckpoint + ): + if self._best_persisted_checkpoint is None: + self._best_persisted_checkpoint = persisted_checkpoint + else: + old_score = self._get_checkpoint_score(self._best_persisted_checkpoint) + candidate_score = self._get_checkpoint_score(persisted_checkpoint) + if candidate_score >= old_score: + self._best_persisted_checkpoint = persisted_checkpoint + + def _get_checkpoint_score( + self, checkpoint: TrackedCheckpoint + ) -> Tuple[bool, numbers.Number, int]: + checkpoint_score_attribute = ( + self._checkpoint_strategy.checkpoint_score_attribute + ) + if checkpoint_score_attribute not in checkpoint.result: + logger.error( + f"Result dict has no key: {checkpoint_score_attribute}. " + f"checkpoint_score_attr must be set to a key in the " + f"result dict. 
Valid keys are: {list(checkpoint.result.keys())}" + ) + checkpoint_result = float("-inf") + else: + checkpoint_result = checkpoint.result[checkpoint_score_attribute] + + checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order + if checkpoint_score_order == MAX: + order_factor = 1.0 + else: + order_factor = -1.0 + + checkpoint_score = order_factor * checkpoint_result + + if not isinstance(checkpoint_score, numbers.Number): + raise ValueError( + f"Unable to persist checkpoint for " + f"checkpoint_score_attribute: " + f"{checkpoint_score_attribute} with value " + f"{checkpoint_score}. " + f"This attribute must be numerical." + ) + + return ( + not is_nan(checkpoint_score), + checkpoint_score if not is_nan(checkpoint_score) else 0, + checkpoint.id, + ) + + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): + checkpoint_score = self._get_checkpoint_score(checkpoint) + wrapped_checkpoint = _HeapCheckpointWrapper( + priority=checkpoint_score, tracked_checkpoint=checkpoint + ) + + if self._checkpoint_strategy.num_to_keep is None: + # Keep all checkpoints + checkpoint.commit(path=self._get_next_checkpoint_path()) + self._replace_latest_persisted_checkpoint(checkpoint) + self._top_persisted_checkpoints.append(wrapped_checkpoint) + elif ( + len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep + ): + # Heap is not full yet, so keep this checkpoint + checkpoint.commit(path=self._get_next_checkpoint_path()) + heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) + self._replace_latest_persisted_checkpoint(checkpoint) + elif wrapped_checkpoint.priority >= self._top_persisted_checkpoints[0].priority: + # Priority is higher than current worst checkpoint, so replace worst + checkpoint.commit(path=self._get_next_checkpoint_path()) + worst_checkpoint = heapq.heappushpop( + self._top_persisted_checkpoints, wrapped_checkpoint + ).tracked_checkpoint + + # Only remove if checkpoint data is different + if worst_checkpoint.dir_or_data != checkpoint.dir_or_data: + self._maybe_delete_persisted_checkpoint(worst_checkpoint) + logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") + + self._replace_latest_persisted_checkpoint(checkpoint) + else: + # If the latest checkpoint has the same or lower priority, skip it. 
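+            # (Scores are (not is_nan(score), score, checkpoint id) tuples, so NaN
+            # metrics always rank as the worst and ties favor the newer checkpoint.)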
+ self._skip_persisted_checkpoint(checkpoint) + + self._maybe_replace_best_persisted_checkpoint(persisted_checkpoint=checkpoint) + self._cleanup_checkpoints() + + def _maybe_delete_persisted_checkpoint( + self, persisted_checkpoint: TrackedCheckpoint + ): + if persisted_checkpoint == self._latest_persisted_checkpoint: + self._checkpoints_to_clean_up.add(persisted_checkpoint) + else: + self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) + + def _delete_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): + persisted_checkpoint.delete(delete_fn=self._delete_fn) + self._checkpoints_to_clean_up.discard(persisted_checkpoint) + + def _cleanup_checkpoints(self): + for checkpoint in list(self._checkpoints_to_clean_up): + self._maybe_delete_persisted_checkpoint(persisted_checkpoint=checkpoint) + + def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): + logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") + self._checkpoints_to_clean_up.add(persisted_checkpoint) + + def _get_next_checkpoint_path(self) -> Optional[Path]: + return None + + def __del__(self): + self._cleanup_checkpoints() + + def __getstate__(self): + state = self.__dict__.copy() + + # Do not serialize the delete fn + state.pop("_delete_fn", None) + + # Avoid serializing the memory checkpoint. + state["_newest_memory_checkpoint"] = TrackedCheckpoint( + dir_or_data=None, + checkpoint_id=0, + storage_mode=TrackedCheckpoint.MEMORY, + ) + return state + + def __setstate__(self, state): + state["_delete_fn"] = None + self.__dict__.update(state) From 743ee438ecbca74051d4906a467451348bfa90bf Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:43:37 +0100 Subject: [PATCH 02/81] [tune/train] Consolidate checkpoint manager 2: Ray Train --- python/ray/train/checkpoint.py | 311 ++++++++++++++------------------- python/ray/train/trainer.py | 11 +- 2 files changed, 136 insertions(+), 186 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index dd03ed3197eb..ffbddb5bd4a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -1,27 +1,28 @@ -import heapq import logging -import numbers -import os -from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import TIMESTAMP, TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID +from ray.train.constants import ( + TIMESTAMP, + TRAIN_CHECKPOINT_SUBDIR, + TUNE_CHECKPOINT_FILE_NAME, + TUNE_CHECKPOINT_ID, + TUNE_INSTALLED, +) from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.util import PublicAPI -from ray.util.ml_utils.util import is_nan +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointManager as CommonCheckpointManager, + TrackedCheckpoint, + CheckpointStrategy, +) if TUNE_INSTALLED: from ray import tune else: tune = None -MAX = "max" -MIN = "min" - logger = logging.getLogger(__name__) @@ -34,64 +35,58 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: return cloudpickle.load(f) -@PublicAPI(stability="beta") -@dataclass -class CheckpointStrategy: - """Configurable parameters for defining the Train checkpointing strategy. - - Default behavior is to persist all checkpoints to disk. 
If - ``num_to_keep`` is set, the default retention policy is to keep the - checkpoints with maximum timestamp, i.e. the most recent checkpoints. - - Args: - num_to_keep (Optional[int]): The number of checkpoints to keep - on disk for this run. If a checkpoint is persisted to disk after - there are already this many checkpoints, then an existing - checkpoint will be deleted. If this is ``None`` then checkpoints - will not be deleted. If this is ``0`` then no checkpoints will be - persisted to disk. - checkpoint_score_attribute (str): The attribute that will be used to - score checkpoints to determine which checkpoints should be kept - on disk when there are greater than ``num_to_keep`` checkpoints. - This attribute must be a key from the checkpoint - dictionary which has a numerical value. - checkpoint_score_order (str). Either "max" or "min". - If "max", then checkpoints with highest values of - ``checkpoint_score_attribute`` will be kept. - If "min", then checkpoints with lowest values of - ``checkpoint_score_attribute`` will be kept. +class _NotYetPersistedCheckpoint(TrackedCheckpoint): + """Tracked checkpoint that is not yet persisted to disk. + + This checkpoint class supports lazy writing. The checkpoint manager will + only call ``commit()`` if the checkpoint should be kept on disk. This class + will only then write checkpoint data to disk. """ - num_to_keep: Optional[int] = None - checkpoint_score_attribute: str = TIMESTAMP - checkpoint_score_order: str = MAX + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def __post_init__(self): - if self.num_to_keep is not None and self.num_to_keep < 0: - raise ValueError( - f"Received invalidate num_to_keep: " - f"{self.num_to_keep}. " - f"Must be None or non-negative integer." - ) - if self.checkpoint_score_order not in (MAX, MIN): - raise ValueError( - f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' - ) + self._data_to_commit = self.dir_or_data + self.dir_or_data = None + @property + def committed(self) -> bool: + return not self._data_to_commit -class PersistedCheckpoint: - def __init__(self, path, priority): - self.path = path - self.priority = priority + def commit(self, path: Optional[Path] = None): + if self.committed: + return - def __lt__(self, other): - return self.priority < other.priority + assert path - def __repr__(self): - return f"PersistedCheckpoint({repr(self.path)})" + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(self._data_to_commit, f) + logger.debug(f"Checkpoint successfully written to: {path}") + self.dir_or_data = path + self._data_to_commit = None -class CheckpointManager: + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): + if not self.committed: + return + return super().delete(delete_fn=delete_fn) + + @classmethod + def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): + new_checkpoint = cls( + dir_or_data=checkpoint.dir_or_data, + storage_mode=TrackedCheckpoint.PERSISTENT, + checkpoint_id=checkpoint.id, + result=checkpoint.result, + node_ip=checkpoint.node_ip, + ) + return new_checkpoint + + +class CheckpointManager(CommonCheckpointManager): """Manages checkpoint processing, writing, and loading. @@ -119,55 +114,16 @@ class CheckpointManager: checkpoint may not be saved to disk. 
""" - def on_init(self, **kwargs): - """Checkpoint code executed during BackendExecutor init.""" - self.latest_checkpoint = None - - # Incremental unique checkpoint ID of this run. - self._latest_checkpoint_id = 0 - - # Used for keeping top K checkpoints. - self._top_persisted_checkpoints = [] - - # Best checkpoint altogether. - # Used for exposing best_checkpoint_path. - self._best_persisted_checkpoint = None - - def on_start_training( - self, - checkpoint_strategy: Optional[CheckpointStrategy], - run_dir: Path, - latest_checkpoint_id: Optional[int] = None, - ): - """Checkpoint code executed during BackendExecutor start_training.""" - # Restart checkpointing. - self._latest_checkpoint_id = latest_checkpoint_id if latest_checkpoint_id else 0 - self._checkpoint_strategy = ( - CheckpointStrategy() if checkpoint_strategy is None else checkpoint_strategy - ) + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir - def _process_checkpoint( - self, - checkpoint_results: List[TrainingResult], - decode_checkpoint_fn: Callable, - ) -> None: - """Perform all processing for a checkpoint.""" - - # Get checkpoint from first worker. - checkpoint = checkpoint_results[0].data + super().__init__(checkpoint_strategy=checkpoint_strategy) - # Decode checkpoint. - checkpoint = decode_checkpoint_fn(checkpoint) + self._validate_checkpoint_strategy() - # Store checkpoint in memory. - self.latest_checkpoint = checkpoint - - # Write checkpoint to disk. - self.write_checkpoint(checkpoint) - - # Increment checkpoint id. - self._latest_checkpoint_id += 1 + def _validate_checkpoint_strategy(self): + if self._checkpoint_strategy.checkpoint_score_attribute is None: + self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP def _load_checkpoint( self, checkpoint_to_load: Optional[Union[Dict, str, Path]] @@ -181,94 +137,77 @@ def _load_checkpoint( # Load checkpoint from path. return load_checkpoint_from_path(checkpoint_to_load) - def write_checkpoint(self, checkpoint: Dict): - """Writes checkpoint to disk.""" - num_to_keep = self._checkpoint_strategy.num_to_keep + def _process_checkpoint( + self, + checkpoint_results: List[TrainingResult], + decode_checkpoint_fn: Callable, + ) -> None: + """Ray Train entrypoint. Perform all processing for a checkpoint.""" + # Get checkpoint from first worker. + checkpoint_data = checkpoint_results[0].data - if num_to_keep == 0: - # Checkpoints should not be persisted to disk. - return + # Decode checkpoint. + checkpoint_data = decode_checkpoint_fn(checkpoint_data) - checkpoint_score_attribute = ( - self._checkpoint_strategy.checkpoint_score_attribute - ) - checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_attribute not in checkpoint: + score_attr = self._checkpoint_strategy.checkpoint_score_attribute + if ( + self._checkpoint_strategy.num_to_keep != 0 + and score_attr not in checkpoint_data + ): raise ValueError( f"Unable to persist checkpoint for " f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute}. " + f"{score_attr}. " f"Include this attribute in the call to " f"train.save_checkpoint." ) - checkpoint_score = checkpoint[checkpoint_score_attribute] - if not isinstance(checkpoint_score, numbers.Number): - raise ValueError( - f"Unable to persist checkpoint for " - f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute} with value " - f"{checkpoint_score}. " - f"This attribute must be numerical." 
+ tracked_checkpoint = TrackedCheckpoint( + dir_or_data=checkpoint_data, + checkpoint_id=self._latest_checkpoint_id, + storage_mode=TrackedCheckpoint.MEMORY, + result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + ) + self.register_checkpoint(checkpoint=tracked_checkpoint) + + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + # Always update the latest memory checkpoint + self._replace_latest_memory_checkpoint(checkpoint) + + # Only process further if we consider keeping this checkpoint on disk + if self._checkpoint_strategy.num_to_keep != 0: + not_yet_persisted_checkpoint = ( + _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) ) + self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - def priority(checkpoint_score_order, checkpoint_score): - # Treat NaN as worst - # The tuple structure is (not is_nan(), metric), which makes - # the nan values to be always considered as the worst - # metrics by the heap - if checkpoint_score_order != MAX: - checkpoint_score = -checkpoint_score - return (not is_nan(checkpoint_score), checkpoint_score) + self._latest_checkpoint_id += 1 + + def _get_next_checkpoint_path(self) -> Optional[Path]: + """Path to the next checkpoint to persist.""" + checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + return self.latest_checkpoint_dir.joinpath(checkpoint_file) - checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) + def on_start_training( + self, + checkpoint_strategy: Optional[CheckpointStrategy], + run_dir: str, + latest_checkpoint_id: Optional[int] = 0, + ): + checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + self._checkpoint_strategy = checkpoint_strategy - persisted_checkpoint = PersistedCheckpoint( - self.next_checkpoint_path, checkpoint_priority - ) + self._validate_checkpoint_strategy() - def write_to_disk(path: Path): - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(checkpoint, f) - logger.debug(f"Checkpoint successfully written to: " f"{path}") - - def remove_from_disk(path: Path): - os.remove(path) - - if num_to_keep is None: - # Keep all checkpoints. - write_to_disk(self.next_checkpoint_path) - elif len(self._top_persisted_checkpoints) < num_to_keep: - # Keep first num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - heapq.heappush(self._top_persisted_checkpoints, persisted_checkpoint) - elif ( - persisted_checkpoint.priority > self._top_persisted_checkpoints[0].priority - ): - # Keep top num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - worst_checkpoint = heapq.heappushpop( - self._top_persisted_checkpoints, persisted_checkpoint - ) - worst_checkpoint_path = worst_checkpoint.path - remove_from_disk(worst_checkpoint_path) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint_path}.") - else: - # If the latest checkpoint has the same or lower priority, skip it. - logger.debug( - f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." - ) + self.run_dir = run_dir + self._latest_checkpoint_id = latest_checkpoint_id or 0 - # Update single best checkpoint. - if ( - self._best_persisted_checkpoint is None - or persisted_checkpoint.priority > self._best_persisted_checkpoint.priority - ): - # If the latest checkpoint has the same or lower priority, skip it. 
- self._best_persisted_checkpoint = persisted_checkpoint + # Train-specific attributes + @property + def latest_checkpoint(self): + if not self._latest_memory_checkpoint: + return None + return self._latest_memory_checkpoint.dir_or_data @property def latest_checkpoint_dir(self) -> Optional[Path]: @@ -294,7 +233,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return self._best_persisted_checkpoint.path + return Path(self._best_persisted_checkpoint.dir_or_data) else: return None @@ -328,16 +267,22 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) + def _decide_what_to_do_with_checkpoint( + self, checkpoint: _NotYetPersistedCheckpoint + ): + assert isinstance(checkpoint, _NotYetPersistedCheckpoint) + assert not checkpoint.committed + + self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) # Use a standard file name so that we know which file to load # the checkpoint from. file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - with file_path.open("wb") as f: - cloudpickle.dump(checkpoint, f) + checkpoint.commit(file_path) + + return super()._decide_what_to_do_with_checkpoint(checkpoint) def construct_checkpoint_file_name(checkpoint_id: int) -> str: diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 798dd99454a0..efb39e9c3535 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,11 +224,16 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) + # Todo (krfricke): Initialize checkpoint manager here with final values + # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager() + self.checkpoint_manager = TuneCheckpointManager( + checkpoint_strategy=None, run_dir=None + ) else: - self.checkpoint_manager = CheckpointManager() - self.checkpoint_manager.on_init() + self.checkpoint_manager = CheckpointManager( + checkpoint_strategy=None, run_dir=None + ) def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From cfeefea3cb423129750897890f14323444c1fc69 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 6 Apr 2022 16:32:29 -0700 Subject: [PATCH 03/81] WIP --- python/ray/train/checkpoint.py | 270 +++++++------ python/ray/tune/checkpoint_manager.py | 72 ++-- .../ray/util/ml_utils/checkpoint_manager.py | 361 +++++------------- 3 files changed, 263 insertions(+), 440 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index ffbddb5bd4a8..7809072934a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -1,22 +1,17 @@ +import heapq import logging +import numbers +import os from pathlib import Path from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import ( - TIMESTAMP, - TRAIN_CHECKPOINT_SUBDIR, - TUNE_CHECKPOINT_FILE_NAME, - TUNE_CHECKPOINT_ID, - TUNE_INSTALLED, -) +from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID +from ray.train.constants import TUNE_INSTALLED, 
TRAIN_CHECKPOINT_SUBDIR from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.util.ml_utils.checkpoint_manager import ( - CheckpointManager as CommonCheckpointManager, - TrackedCheckpoint, - CheckpointStrategy, -) +from ray.util.ml_utils.checkpoint_manager import MAX, \ + CheckpointManager as CommonCheckpointManager, _TrackedCheckpoint if TUNE_INSTALLED: from ray import tune @@ -35,58 +30,66 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: return cloudpickle.load(f) -class _NotYetPersistedCheckpoint(TrackedCheckpoint): - """Tracked checkpoint that is not yet persisted to disk. +class _NotYetPersistedCheckpoint(_TrackedCheckpoint): + def commit(self): + # Todo: write self.checkpoint_data_or_dict to disk + pass - This checkpoint class supports lazy writing. The checkpoint manager will - only call ``commit()`` if the checkpoint should be kept on disk. This class - will only then write checkpoint data to disk. - """ + def delete(self): + pass - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._data_to_commit = self.dir_or_data - self.dir_or_data = None +class CheckpointManager(CommonCheckpointManager): + def __init__(self, run_dir: str): + self.run_dir = run_dir - @property - def committed(self) -> bool: - return not self._data_to_commit + def _load_checkpoint( + self, checkpoint_to_load: Optional[Union[Dict, str, Path]] + ) -> Optional[Dict]: + """Load the checkpoint dictionary from the input dict or path.""" + if checkpoint_to_load is None: + return None + if isinstance(checkpoint_to_load, Dict): + return checkpoint_to_load + else: + # Load checkpoint from path. + return load_checkpoint_from_path(checkpoint_to_load) - def commit(self, path: Optional[Path] = None): - if self.committed: - return + def _process_checkpoint( + self, + checkpoint_results: List[TrainingResult], + decode_checkpoint_fn: Callable, + ) -> None: + """Perform all processing for a checkpoint.""" - assert path + # Get checkpoint from first worker. + checkpoint_data = checkpoint_results[0].data - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(self._data_to_commit, f) - logger.debug(f"Checkpoint successfully written to: {path}") + # Decode checkpoint. + checkpoint_data = decode_checkpoint_fn(checkpoint_data) - self.dir_or_data = path - self._data_to_commit = None + score_attr = self._checkpoint_strategy.checkpoint_score_attribute + tracked_checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data=checkpoint_data, + checkpoint_id=self._latest_checkpoint_id, + storage_mode=_TrackedCheckpoint.MEMORY, + result={score_attr: checkpoint_data[score_attr]}, - def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): - if not self.committed: - return - return super().delete(delete_fn=delete_fn) - - @classmethod - def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): - new_checkpoint = cls( - dir_or_data=checkpoint.dir_or_data, - storage_mode=TrackedCheckpoint.PERSISTENT, - checkpoint_id=checkpoint.id, - result=checkpoint.result, - node_ip=checkpoint.node_ip, ) - return new_checkpoint + self.decide_what_to_do_with_checkpoint(tracked_checkpoint) -class CheckpointManager(CommonCheckpointManager): + # Store checkpoint in memory. + self.latest_checkpoint = checkpoint_data + + # Write checkpoint to disk. + self.write_checkpoint(checkpoint_data) + + # Increment checkpoint id. 
+ self._latest_checkpoint_id += 1 + + +class CheckpointManagerLegacy: """Manages checkpoint processing, writing, and loading. @@ -114,100 +117,93 @@ class CheckpointManager(CommonCheckpointManager): checkpoint may not be saved to disk. """ - def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): - self.run_dir = run_dir - - super().__init__(checkpoint_strategy=checkpoint_strategy) - - self._validate_checkpoint_strategy() - - def _validate_checkpoint_strategy(self): - if self._checkpoint_strategy.checkpoint_score_attribute is None: - self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP - def _load_checkpoint( - self, checkpoint_to_load: Optional[Union[Dict, str, Path]] - ) -> Optional[Dict]: - """Load the checkpoint dictionary from the input dict or path.""" - if checkpoint_to_load is None: - return None - if isinstance(checkpoint_to_load, Dict): - return checkpoint_to_load - else: - # Load checkpoint from path. - return load_checkpoint_from_path(checkpoint_to_load) - def _process_checkpoint( - self, - checkpoint_results: List[TrainingResult], - decode_checkpoint_fn: Callable, - ) -> None: - """Ray Train entrypoint. Perform all processing for a checkpoint.""" - # Get checkpoint from first worker. - checkpoint_data = checkpoint_results[0].data + def write_checkpoint(self, checkpoint: Dict): + """Writes checkpoint to disk.""" + num_to_keep = self._checkpoint_strategy.num_to_keep - # Decode checkpoint. - checkpoint_data = decode_checkpoint_fn(checkpoint_data) + if num_to_keep == 0: + # Checkpoints should not be persisted to disk. + return - score_attr = self._checkpoint_strategy.checkpoint_score_attribute - if ( - self._checkpoint_strategy.num_to_keep != 0 - and score_attr not in checkpoint_data - ): + checkpoint_score_attribute = ( + self._checkpoint_strategy.checkpoint_score_attribute + ) + checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order + if checkpoint_score_attribute not in checkpoint: raise ValueError( f"Unable to persist checkpoint for " f"checkpoint_score_attribute: " - f"{score_attr}. " + f"{checkpoint_score_attribute}. " f"Include this attribute in the call to " f"train.save_checkpoint." ) + checkpoint_score = checkpoint[checkpoint_score_attribute] - tracked_checkpoint = TrackedCheckpoint( - dir_or_data=checkpoint_data, - checkpoint_id=self._latest_checkpoint_id, - storage_mode=TrackedCheckpoint.MEMORY, - result={score_attr: checkpoint_data.get(score_attr, 0.0)}, - ) - self.register_checkpoint(checkpoint=tracked_checkpoint) - - def register_checkpoint(self, checkpoint: TrackedCheckpoint): - # Always update the latest memory checkpoint - self._replace_latest_memory_checkpoint(checkpoint) - - # Only process further if we consider keeping this checkpoint on disk - if self._checkpoint_strategy.num_to_keep != 0: - not_yet_persisted_checkpoint = ( - _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) + if not isinstance(checkpoint_score, numbers.Number): + raise ValueError( + f"Unable to persist checkpoint for " + f"checkpoint_score_attribute: " + f"{checkpoint_score_attribute} with value " + f"{checkpoint_score}. " + f"This attribute must be numerical." 
) - self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - self._latest_checkpoint_id += 1 + def priority(checkpoint_score_order, checkpoint_score): + if checkpoint_score_order == MAX: + return checkpoint_score + else: + return -checkpoint_score - def _get_next_checkpoint_path(self) -> Optional[Path]: - """Path to the next checkpoint to persist.""" - checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) - return self.latest_checkpoint_dir.joinpath(checkpoint_file) + checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) - def on_start_training( - self, - checkpoint_strategy: Optional[CheckpointStrategy], - run_dir: str, - latest_checkpoint_id: Optional[int] = 0, - ): - checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() - self._checkpoint_strategy = checkpoint_strategy - - self._validate_checkpoint_strategy() + persisted_checkpoint = PersistedCheckpoint( + self.next_checkpoint_path, checkpoint_priority + ) - self.run_dir = run_dir - self._latest_checkpoint_id = latest_checkpoint_id or 0 + def write_to_disk(path: Path): + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(checkpoint, f) + logger.debug(f"Checkpoint successfully written to: " f"{path}") + + def remove_from_disk(path: Path): + os.remove(path) + + if num_to_keep is None: + # Keep all checkpoints. + write_to_disk(self.next_checkpoint_path) + elif len(self._top_persisted_checkpoints) < num_to_keep: + # Keep first num_to_keep checkpoints. + write_to_disk(self.next_checkpoint_path) + heapq.heappush(self._top_persisted_checkpoints, persisted_checkpoint) + elif ( + persisted_checkpoint.priority > self._top_persisted_checkpoints[0].priority + ): + # Keep top num_to_keep checkpoints. + write_to_disk(self.next_checkpoint_path) + worst_checkpoint = heapq.heappushpop( + self._top_persisted_checkpoints, persisted_checkpoint + ) + worst_checkpoint_path = worst_checkpoint.path + remove_from_disk(worst_checkpoint_path) + logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint_path}.") + else: + # If the latest checkpoint has the same or lower priority, skip it. + logger.debug( + f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." + ) - # Train-specific attributes - @property - def latest_checkpoint(self): - if not self._latest_memory_checkpoint: - return None - return self._latest_memory_checkpoint.dir_or_data + # Update single best checkpoint. + if ( + self._best_persisted_checkpoint is None + or persisted_checkpoint.priority > self._best_persisted_checkpoint.priority + ): + # If the latest checkpoint has the same or lower priority, skip it. + self._best_persisted_checkpoint = persisted_checkpoint @property def latest_checkpoint_dir(self) -> Optional[Path]: @@ -233,7 +229,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return Path(self._best_persisted_checkpoint.dir_or_data) + return self._best_persisted_checkpoint.path else: return None @@ -267,22 +263,16 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. 
checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint( - self, checkpoint: _NotYetPersistedCheckpoint - ): - assert isinstance(checkpoint, _NotYetPersistedCheckpoint) - assert not checkpoint.committed - - self.add_tune_checkpoint_id(checkpoint._data_to_commit) + def write_checkpoint(self, checkpoint: Dict): + self.add_tune_checkpoint_id(checkpoint) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) # Use a standard file name so that we know which file to load # the checkpoint from. file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - checkpoint.commit(file_path) - - return super()._decide_what_to_do_with_checkpoint(checkpoint) + with file_path.open("wb") as f: + cloudpickle.dump(checkpoint, f) def construct_checkpoint_file_name(checkpoint_id: int) -> str: diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 75cf4b8cb835..fc42371651c7 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -1,11 +1,13 @@ # coding: utf-8 -import heapq import gc +import heapq import logging from typing import Any, Callable, Optional from ray.tune.result import NODE_IP -from ray.tune.utils.util import flatten_dict, is_nan +from ray.tune.utils.util import flatten_dict +from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, MIN, MAX, \ + CheckpointManager as CommonCheckpointManager logger = logging.getLogger(__name__) @@ -74,7 +76,29 @@ def __repr__(self): return f"QueueItem({repr(self.value)})" -class CheckpointManager: +class CheckpointManager(CommonCheckpointManager): + def __init__( + self, + keep_checkpoints_num: int, + checkpoint_score_attr: str, + delete_fn: Callable[[str], None], + ): + checkpoint_score_desc = checkpoint_score_attr.startswith("min-") + if checkpoint_score_desc: + checkpoint_score_attr = checkpoint_score_attr[4:] + else: + checkpoint_score_attr = checkpoint_score_attr + + checkpoint_strategy = CheckpointStrategy( + num_to_keep=keep_checkpoints_num, + checkpoint_score_attribute=checkpoint_score_attr, + checkpoint_score_order=MIN if checkpoint_score_desc else MAX + ) + + def + + +class CheckpointManagerLegacy: """Manages checkpoints on the driver for a trial.""" def __init__( @@ -106,22 +130,14 @@ def __init__( self._checkpoint_score_attr = checkpoint_score_attr self.delete = delete_fn - self._newest_persistent_checkpoint = None + self.newest_persistent_checkpoint = _TuneCheckpoint( + _TuneCheckpoint.PERSISTENT, None + ) self._newest_memory_checkpoint = _TuneCheckpoint(_TuneCheckpoint.MEMORY, None) self._best_checkpoints = [] self._membership = set() self._cur_order = 0 - @property - def newest_persistent_checkpoint(self): - return self._newest_persistent_checkpoint or _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, None - ) - - @newest_persistent_checkpoint.setter - def newest_persistent_checkpoint(self, value): - self._newest_persistent_checkpoint = value - @property def newest_checkpoint(self): """Returns the newest checkpoint (based on training iteration).""" @@ -162,9 +178,9 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): self.replace_newest_memory_checkpoint(checkpoint) return - old_checkpoint = self._newest_persistent_checkpoint + old_checkpoint = self.newest_persistent_checkpoint - if old_checkpoint and old_checkpoint.value == checkpoint.value: + if old_checkpoint.value == checkpoint.value: # Overwrite the order of the 
checkpoint. old_checkpoint.order = checkpoint.order return @@ -172,26 +188,16 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): self.newest_persistent_checkpoint = checkpoint # Remove the old checkpoint if it isn't one of the best ones. - if ( - old_checkpoint - and old_checkpoint.value - and old_checkpoint not in self._membership - ): + if old_checkpoint.value and old_checkpoint not in self._membership: self.delete(old_checkpoint) try: - # NaN metrics are treated as worst checkpoint - # The tuple structure is (not is_nan(), metric), which makes - # the nan values to be always considered as the worst - # metrics by the heap queue_item = QueueItem(self._priority(checkpoint), checkpoint) except KeyError: logger.error( "Result dict has no key: {}. " - "checkpoint_score_attr must be set to a key of the " - "result dict. Valid keys are {}".format( - self._checkpoint_score_attr, list(checkpoint.result.keys()) - ) + "checkpoint_score_attr must be set to a key in the " + "result dict.".format(self._checkpoint_score_attr) ) return @@ -216,13 +222,7 @@ def best_checkpoints(self): def _priority(self, checkpoint): result = flatten_dict(checkpoint.result) priority = result[self._checkpoint_score_attr] - if self._checkpoint_score_desc: - priority = -priority - return ( - not is_nan(priority), - priority if not is_nan(priority) else 0, - checkpoint.order, - ) + return -priority if self._checkpoint_score_desc else priority def __getstate__(self): state = self.__dict__.copy() diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index f97068b58377..63e6f7afc9d6 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,19 +1,19 @@ -import gc import heapq import logging import numbers -import os -import shutil - from dataclasses import dataclass from pathlib import Path -from typing import Optional, Dict, Union, Callable, Tuple, List, Any +from typing import Optional, Dict, Union import ray +from ray.train.constants import TIMESTAMP, TUNE_INSTALLED from ray.tune.result import NODE_IP from ray.util import PublicAPI -from ray.util.annotations import DeveloperAPI -from ray.util.ml_utils.util import is_nan + +if TUNE_INSTALLED: + pass +else: + tune = None MAX = "max" MIN = "min" @@ -21,96 +21,34 @@ logger = logging.getLogger(__name__) -@DeveloperAPI -class TrackedCheckpoint: - """Checkpoint tracked by a checkpoint manager. - - This class is used to track checkpoints generated by trainables and trainers in - order to add metadata (e.g. the result, or the node where it has been created) - and for bookkeeping purposes. - - Args: - dir_or_data: Checkpoint directory, checkpoint data, or a future to either. - storage_mode: Either MEMORY or PERSISTENT. - checkpoint_id: Checkpoint number. Usually this should be monotonically - increasing for each tracked checkpoint. - result: Observed metrics for this checkpoint. This is used to determine - the value of the ``checkpoint_score_attr``. - node_ip: IP of the node where the checkpoint was generated. Defaults - to the current node. 
- """ - +class _TrackedCheckpoint: MEMORY = "memory" PERSISTENT = "persistent" - def __init__( - self, - dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], - storage_mode: str, - checkpoint_id: Optional[int] = None, - result: Optional[Dict] = None, - node_ip: Optional[str] = None, - ): - self.dir_or_data = dir_or_data - self.id = checkpoint_id + def __init__(self, + checkpoint_dir_or_data: Union[str, Path, Dict, ray.ObjectRef], + checkpoint_id: int, + storage_mode: str, + result: Optional[Dict] = None, + node_ip: Optional[str] = None, + ): + self.checkpoint_dir_or_data = checkpoint_dir_or_data + self.checkpoint_id = checkpoint_id self.storage_mode = storage_mode - # Todo: What to do if result is a subset of dir_or_data (dict) + # Todo: What to do if result is a subset of checkpoint_dir_or_data (dict) self.result = result or {} self.node_ip = node_ip or self.result.get(NODE_IP, None) - def commit(self, path: Optional[Path] = None) -> None: - """Commit checkpoint to disk, if needed. - - Args: - path: Path to commit checkpoint to. - """ + def commit(self): pass - def delete( - self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None - ) -> None: - """Delete checkpoint from disk, if needed. - - Args: - delete_fn: Function to be called with the tracked checkpoint as an - argument. Defaults to removing the local directory/file. - """ - delete_fn = delete_fn or _default_delete_fn - try: - delete_fn(self) - except Exception as e: - logger.warning(f"Checkpoint deletion failed: {e}") - - def __repr__(self): - if self.storage_mode == TrackedCheckpoint.MEMORY: - return f"" - - return ( - f"" - ) - - -def _default_delete_fn(checkpoint: TrackedCheckpoint): - if checkpoint.storage_mode != TrackedCheckpoint.PERSISTENT: - return - - if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): - if os.path.isfile(checkpoint.dir_or_data): - os.remove(checkpoint.dir_or_data) - return - elif os.path.isdir(checkpoint.dir_or_data): - shutil.rmtree(checkpoint.dir_or_data) - return - raise RuntimeError( - f"Could not delete checkpoint {checkpoint} from disk as it is " - f"neither file not directory. Path: {checkpoint.dir_or_data}." - ) + def delete(self): + pass class _HeapCheckpointWrapper: - def __init__(self, priority: Any, tracked_checkpoint: TrackedCheckpoint): + def __init__(self, priority: numbers.Number, tracked_checkpoint: _TrackedCheckpoint): self.priority = priority self.tracked_checkpoint = tracked_checkpoint @@ -124,7 +62,7 @@ def __repr__(self): @PublicAPI(stability="beta") @dataclass class CheckpointStrategy: - """Configurable parameters for defining the checkpointing strategy. + """Configurable parameters for defining the Train checkpointing strategy. Default behavior is to persist all checkpoints to disk. If ``num_to_keep`` is set, the default retention policy is to keep the @@ -141,8 +79,7 @@ class CheckpointStrategy: score checkpoints to determine which checkpoints should be kept on disk when there are greater than ``num_to_keep`` checkpoints. This attribute must be a key from the checkpoint - dictionary which has a numerical value. Per default, the last - checkpoints will be kept. + dictionary which has a numerical value. checkpoint_score_order (str). Either "max" or "min". If "max", then checkpoints with highest values of ``checkpoint_score_attribute`` will be kept. 
@@ -151,13 +88,13 @@ class CheckpointStrategy: """ num_to_keep: Optional[int] = None - checkpoint_score_attribute: Optional[str] = None + checkpoint_score_attribute: str = TIMESTAMP checkpoint_score_order: str = MAX def __post_init__(self): if self.num_to_keep is not None and self.num_to_keep < 0: raise ValueError( - f"Received invalid num_to_keep: " + f"Received invalidate num_to_keep: " f"{self.num_to_keep}. " f"Must be None or non-negative integer." ) @@ -168,117 +105,67 @@ def __post_init__(self): class CheckpointManager: - """Common checkpoint management and bookkeeping class for Ray Train and Tune. + def __init__(self, checkpoint_strategy: CheckpointStrategy, latest_checkpoint_id: int = 0): + self._latest_checkpoint_id = latest_checkpoint_id # Todo (krfricke): Review if needed + self._checkpoint_strategy = checkpoint_strategy - This class acts as the common core for checkpoint bookkeeping in Ray ML libraries. - On a high level, this manager keeps a reference to all stored checkpoints - (both in-memory and on-disk checkpoints). For on-disk checkpoints, it - keeps a configured number of checkpoints according to specified metrics. - - The manager supports lazy data writing by utilizing the - ``TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint - should be persisted to disk. - """ - - def __init__( - self, - checkpoint_strategy: CheckpointStrategy, - latest_checkpoint_id: int = 0, - delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, - ): - self._checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + self.latest_checkpoint = None # Incremental unique checkpoint ID of this run. - self._latest_checkpoint_id = latest_checkpoint_id + self._latest_checkpoint_id = 0 # Used for keeping top K checkpoints. - self._top_persisted_checkpoints: List[_HeapCheckpointWrapper] = [] + self._top_persisted_checkpoints = [] # Best checkpoint altogether. # Used for exposing best_checkpoint_path. - self._best_persisted_checkpoint: Optional[TrackedCheckpoint] = None - self._latest_persisted_checkpoint: Optional[TrackedCheckpoint] = None - self._latest_memory_checkpoint: Optional[TrackedCheckpoint] = None - - # Checkpoints that are not immediately removed - self._checkpoints_to_clean_up = set() - self._delete_fn = delete_fn - - def set_delete_fn(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]]): - """Update the function called to delete persisted checkpoints. - - Args: - delete_fn: Function that takes a tracked checkpoint as an argument and - deletes it from disk. - """ - self._delete_fn = delete_fn - - def register_checkpoint(self, checkpoint: TrackedCheckpoint): - """Register new checkpoint and add to bookkeeping. - - This method will register a new checkpoint and add it to the internal - bookkeeping logic. This means the checkpoint manager will decide if - this checkpoint should be kept, and if older or worse performing - checkpoints should be deleted. - - Subclasses have to implement this method. - - Args: - checkpoint: Tracked checkpoint object to add to bookkeeping. 
- """ - raise NotImplementedError - - def _replace_latest_memory_checkpoint(self, memory_checkpoint: TrackedCheckpoint): - assert memory_checkpoint.storage_mode == TrackedCheckpoint.MEMORY - self._latest_memory_checkpoint = memory_checkpoint - # Avoid memory leaks on k8s pods - gc.collect() - - def _replace_latest_persisted_checkpoint( - self, persisted_checkpoint: TrackedCheckpoint - ): - second_to_latest_persisted_checkpoint = self._latest_persisted_checkpoint - self._latest_persisted_checkpoint = persisted_checkpoint - - if self._checkpoint_strategy.num_to_keep == 0: - self._maybe_delete_persisted_checkpoint( - second_to_latest_persisted_checkpoint - ) + self._best_persisted_checkpoint_wrapped: Optional[_HeapCheckpointWrapper] = None + + self._last_persisted_checkpoint: Optional[_TrackedCheckpoint] = None + + self._last_memory_checkpoint: Optional[_TrackedCheckpoint] = None + + # Do we need this at all? + @property + def _best_persisted_checkpoint(self) -> _TrackedCheckpoint: + return self._best_persisted_checkpoint_wrapped.tracked_checkpoint + + def decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): + if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: + self._last_memory_checkpoint = checkpoint + + second_to_last_persisted_checkpoint = self._last_persisted_checkpoint + self._last_persisted_checkpoint = checkpoint + + num_to_keep = self._checkpoint_strategy.num_to_keep + + if num_to_keep == 0: + if second_to_last_persisted_checkpoint: + second_to_last_persisted_checkpoint.delete() + + # Checkpoints should not be persisted to disk. + return - def _maybe_replace_best_persisted_checkpoint( - self, persisted_checkpoint: TrackedCheckpoint - ): - if self._best_persisted_checkpoint is None: - self._best_persisted_checkpoint = persisted_checkpoint - else: - old_score = self._get_checkpoint_score(self._best_persisted_checkpoint) - candidate_score = self._get_checkpoint_score(persisted_checkpoint) - if candidate_score >= old_score: - self._best_persisted_checkpoint = persisted_checkpoint - - def _get_checkpoint_score( - self, checkpoint: TrackedCheckpoint - ) -> Tuple[bool, numbers.Number, int]: checkpoint_score_attribute = ( self._checkpoint_strategy.checkpoint_score_attribute ) - if checkpoint_score_attribute not in checkpoint.result: - logger.error( - f"Result dict has no key: {checkpoint_score_attribute}. " - f"checkpoint_score_attr must be set to a key in the " - f"result dict. Valid keys are: {list(checkpoint.result.keys())}" + + if checkpoint_score_attribute not in checkpoint: + raise ValueError( + f"Unable to persist checkpoint for " + f"checkpoint_score_attribute: " + f"{checkpoint_score_attribute}. " + f"Include this attribute in the call to " + f"train.save_checkpoint." ) - checkpoint_result = float("-inf") - else: - checkpoint_result = checkpoint.result[checkpoint_score_attribute] checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_order == MAX: - order_factor = 1.0 + if checkpoint_score_order == MIN: + order_factor = 1. else: - order_factor = -1.0 + order_factor = -1. - checkpoint_score = order_factor * checkpoint_result + checkpoint_score = order_factor * checkpoint.result[checkpoint_score_attribute] if not isinstance(checkpoint_score, numbers.Number): raise ValueError( @@ -289,90 +176,36 @@ def _get_checkpoint_score( f"This attribute must be numerical." 
) - return ( - not is_nan(checkpoint_score), - checkpoint_score if not is_nan(checkpoint_score) else 0, - checkpoint.id, - ) - - def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): - checkpoint_score = self._get_checkpoint_score(checkpoint) - wrapped_checkpoint = _HeapCheckpointWrapper( - priority=checkpoint_score, tracked_checkpoint=checkpoint - ) + wrapped_checkpoint = _HeapCheckpointWrapper(priority=checkpoint_score, tracked_checkpoint=checkpoint) - if self._checkpoint_strategy.num_to_keep is None: - # Keep all checkpoints - checkpoint.commit(path=self._get_next_checkpoint_path()) - self._replace_latest_persisted_checkpoint(checkpoint) - self._top_persisted_checkpoints.append(wrapped_checkpoint) + if num_to_keep is None: + # Keep all checkpoints. + checkpoint.commit() + # Todo: track as latest available persisted checkpoint + pass + elif len(self._top_persisted_checkpoints) < num_to_keep: + # Keep this checkpoint + checkpoint.commit() + heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) elif ( - len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep + wrapped_checkpoint.priority > self._top_persisted_checkpoints[0].priority ): - # Heap is not full yet, so keep this checkpoint - checkpoint.commit(path=self._get_next_checkpoint_path()) - heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) - self._replace_latest_persisted_checkpoint(checkpoint) - elif wrapped_checkpoint.priority >= self._top_persisted_checkpoints[0].priority: - # Priority is higher than current worst checkpoint, so replace worst - checkpoint.commit(path=self._get_next_checkpoint_path()) - worst_checkpoint = heapq.heappushpop( - self._top_persisted_checkpoints, wrapped_checkpoint - ).tracked_checkpoint - - # Only remove if checkpoint data is different - if worst_checkpoint.dir_or_data != checkpoint.dir_or_data: - self._maybe_delete_persisted_checkpoint(worst_checkpoint) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") - - self._replace_latest_persisted_checkpoint(checkpoint) + # Write checkpoint to disk if not yet persisted + checkpoint.commit() + worst_checkpoint = heapq.heappushpop(self._top_persisted_checkpoints, wrapped_checkpoint).tracked_checkpoint + worst_checkpoint.delete() + logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") else: # If the latest checkpoint has the same or lower priority, skip it. 
- self._skip_persisted_checkpoint(checkpoint) - - self._maybe_replace_best_persisted_checkpoint(persisted_checkpoint=checkpoint) - self._cleanup_checkpoints() - - def _maybe_delete_persisted_checkpoint( - self, persisted_checkpoint: TrackedCheckpoint - ): - if persisted_checkpoint == self._latest_persisted_checkpoint: - self._checkpoints_to_clean_up.add(persisted_checkpoint) - else: - self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) - - def _delete_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): - persisted_checkpoint.delete(delete_fn=self._delete_fn) - self._checkpoints_to_clean_up.discard(persisted_checkpoint) - - def _cleanup_checkpoints(self): - for checkpoint in list(self._checkpoints_to_clean_up): - self._maybe_delete_persisted_checkpoint(persisted_checkpoint=checkpoint) - - def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): - logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") - self._checkpoints_to_clean_up.add(persisted_checkpoint) - - def _get_next_checkpoint_path(self) -> Optional[Path]: - return None - - def __del__(self): - self._cleanup_checkpoints() - - def __getstate__(self): - state = self.__dict__.copy() - - # Do not serialize the delete fn - state.pop("_delete_fn", None) - - # Avoid serializing the memory checkpoint. - state["_newest_memory_checkpoint"] = TrackedCheckpoint( - dir_or_data=None, - checkpoint_id=0, - storage_mode=TrackedCheckpoint.MEMORY, - ) - return state + # Todo: fix this + logger.debug( + f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." + ) - def __setstate__(self, state): - state["_delete_fn"] = None - self.__dict__.update(state) + # Update single best checkpoint. + if ( + self._best_persisted_checkpoint is None + or wrapped_checkpoint.priority > self._best_persisted_checkpoint_wrapped.priority + ): + # If the latest checkpoint has the same or lower priority, skip it. 
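Dropping the delete function before pickling and restoring a placeholder afterwards is a common pattern when an object holds callables that may capture lambdas or actor handles. A hedged, standalone sketch of that serialization pattern (the class here is a stand-in, not the checkpoint manager itself):

import pickle

class Manager:
    def __init__(self, delete_fn=None):
        # delete_fn may be a lambda or bound method that cannot be pickled.
        self._delete_fn = delete_fn or (lambda path: None)
        self.kept = ["ckpt_000001", "ckpt_000002"]

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_delete_fn", None)   # never serialize the callable
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._delete_fn = lambda path: None  # caller re-injects the real one

m = Manager(delete_fn=lambda path: print("deleting", path))
restored = pickle.loads(pickle.dumps(m))
assert restored.kept == m.kept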
+ self._best_persisted_checkpoint_wrapped = checkpoint From bd0858dd66ba8e12bf68a3f0d1f07f6d8dcc4ddb Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 3 May 2022 13:33:31 +0100 Subject: [PATCH 04/81] Continue consolidation --- python/ray/train/checkpoint.py | 214 +++++++---------- python/ray/tune/checkpoint_manager.py | 187 ++++++--------- python/ray/tune/tests/test_trial_runner_2.py | 4 +- python/ray/tune/trial.py | 4 +- .../ray/util/ml_utils/checkpoint_manager.py | 216 +++++++++++------- 5 files changed, 292 insertions(+), 333 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 7809072934a8..1e183762a756 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -1,7 +1,4 @@ -import heapq import logging -import numbers -import os from pathlib import Path from typing import List, Optional, Dict, Union, Callable @@ -10,8 +7,11 @@ from ray.train.constants import TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.util.ml_utils.checkpoint_manager import MAX, \ - CheckpointManager as CommonCheckpointManager, _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointManager as CommonCheckpointManager, + _TrackedCheckpoint, + CheckpointStrategy, +) if TUNE_INSTALLED: from ray import tune @@ -31,65 +31,40 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: class _NotYetPersistedCheckpoint(_TrackedCheckpoint): - def commit(self): - # Todo: write self.checkpoint_data_or_dict to disk - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def delete(self): - pass - - -class CheckpointManager(CommonCheckpointManager): - def __init__(self, run_dir: str): - self.run_dir = run_dir + self._committed = False - def _load_checkpoint( - self, checkpoint_to_load: Optional[Union[Dict, str, Path]] - ) -> Optional[Dict]: - """Load the checkpoint dictionary from the input dict or path.""" - if checkpoint_to_load is None: - return None - if isinstance(checkpoint_to_load, Dict): - return checkpoint_to_load - else: - # Load checkpoint from path. - return load_checkpoint_from_path(checkpoint_to_load) + def commit(self, path: Optional[Path] = None): + if self._committed: + return - def _process_checkpoint( - self, - checkpoint_results: List[TrainingResult], - decode_checkpoint_fn: Callable, - ) -> None: - """Perform all processing for a checkpoint.""" + assert path - # Get checkpoint from first worker. - checkpoint_data = checkpoint_results[0].data + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(self, f) + logger.debug(f"Checkpoint successfully written to: {path}") - # Decode checkpoint. 
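The not-yet-persisted wrapper introduced in this patch defers the actual disk write until the manager decides the checkpoint is worth keeping, and makes the write idempotent. A minimal sketch of that write-once behaviour, using the standard-library pickle instead of Ray's cloudpickle to stay dependency-free:

import pickle
import tempfile
from pathlib import Path

class LazyCheckpoint:
    def __init__(self, data: dict):
        self.data = data
        self._committed = False

    def commit(self, path: Path) -> None:
        if self._committed:          # idempotent: only the first call writes
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as f:
            pickle.dump(self.data, f)
        self._committed = True

tmp = Path(tempfile.mkdtemp())
ckpt = LazyCheckpoint({"epoch": 3, "loss": 0.12})
ckpt.commit(tmp / "checkpoint_000003" / "ckpt.pkl")
ckpt.commit(tmp / "ignored" / "ckpt.pkl")   # no-op, already committed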
- checkpoint_data = decode_checkpoint_fn(checkpoint_data) + self._committed = True - score_attr = self._checkpoint_strategy.checkpoint_score_attribute - tracked_checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data=checkpoint_data, - checkpoint_id=self._latest_checkpoint_id, - storage_mode=_TrackedCheckpoint.MEMORY, - result={score_attr: checkpoint_data[score_attr]}, + def delete(self): + if not self._committed: + return + return super().delete() + @classmethod + def from_tracked_checkpoint(cls, checkpoint: _TrackedCheckpoint): + new_checkpoint = cls( + **{**checkpoint.__dict__, "storage_mode": _TrackedCheckpoint.PERSISTENT} ) + return new_checkpoint - self.decide_what_to_do_with_checkpoint(tracked_checkpoint) - # Store checkpoint in memory. - self.latest_checkpoint = checkpoint_data - - # Write checkpoint to disk. - self.write_checkpoint(checkpoint_data) - - # Increment checkpoint id. - self._latest_checkpoint_id += 1 - - -class CheckpointManagerLegacy: +class CheckpointManager(CommonCheckpointManager): """Manages checkpoint processing, writing, and loading. @@ -117,94 +92,72 @@ class CheckpointManagerLegacy: checkpoint may not be saved to disk. """ + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): + self.run_dir = run_dir + super().__init__(checkpoint_strategy=checkpoint_strategy) - def write_checkpoint(self, checkpoint: Dict): - """Writes checkpoint to disk.""" - num_to_keep = self._checkpoint_strategy.num_to_keep + def _load_checkpoint( + self, checkpoint_to_load: Optional[Union[Dict, str, Path]] + ) -> Optional[Dict]: + """Load the checkpoint dictionary from the input dict or path.""" + if checkpoint_to_load is None: + return None + if isinstance(checkpoint_to_load, Dict): + return checkpoint_to_load + else: + # Load checkpoint from path. + return load_checkpoint_from_path(checkpoint_to_load) - if num_to_keep == 0: - # Checkpoints should not be persisted to disk. - return + def _process_checkpoint( + self, + checkpoint_results: List[TrainingResult], + decode_checkpoint_fn: Callable, + ) -> None: + """Perform all processing for a checkpoint.""" - checkpoint_score_attribute = ( - self._checkpoint_strategy.checkpoint_score_attribute - ) - checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_attribute not in checkpoint: + # Get checkpoint from first worker. + checkpoint_data = checkpoint_results[0].data + + # Decode checkpoint. + checkpoint_data = decode_checkpoint_fn(checkpoint_data) + + score_attr = self._checkpoint_strategy.checkpoint_score_attribute + if ( + self._checkpoint_strategy.num_to_keep != 0 + and score_attr not in checkpoint_data + ): raise ValueError( f"Unable to persist checkpoint for " f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute}. " + f"{score_attr}. " f"Include this attribute in the call to " f"train.save_checkpoint." ) - checkpoint_score = checkpoint[checkpoint_score_attribute] - - if not isinstance(checkpoint_score, numbers.Number): - raise ValueError( - f"Unable to persist checkpoint for " - f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute} with value " - f"{checkpoint_score}. " - f"This attribute must be numerical." 
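`from_tracked_checkpoint` above is essentially an alternate constructor that clones another object's attributes while overriding one field. A small sketch of that idiom with invented class names, assuming the attributes mirror the constructor arguments:

class Record:
    def __init__(self, data, storage_mode="memory", committed=False):
        self.data = data
        self.storage_mode = storage_mode
        self.committed = committed

class PendingRecord(Record):
    @classmethod
    def from_record(cls, record: Record) -> "PendingRecord":
        # Copy all attributes of the source, but force persistent storage.
        return cls(**{**record.__dict__, "storage_mode": "persistent"})

pending = PendingRecord.from_record(Record(data={"step": 1}))
assert pending.storage_mode == "persistent" and pending.data == {"step": 1}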
- ) - def priority(checkpoint_score_order, checkpoint_score): - if checkpoint_score_order == MAX: - return checkpoint_score - else: - return -checkpoint_score - - checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) - - persisted_checkpoint = PersistedCheckpoint( - self.next_checkpoint_path, checkpoint_priority + tracked_checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data=checkpoint_data, + checkpoint_id=self._latest_checkpoint_id, + storage_mode=_TrackedCheckpoint.MEMORY, + result={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) - def write_to_disk(path: Path): - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(checkpoint, f) - logger.debug(f"Checkpoint successfully written to: " f"{path}") - - def remove_from_disk(path: Path): - os.remove(path) - - if num_to_keep is None: - # Keep all checkpoints. - write_to_disk(self.next_checkpoint_path) - elif len(self._top_persisted_checkpoints) < num_to_keep: - # Keep first num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - heapq.heappush(self._top_persisted_checkpoints, persisted_checkpoint) - elif ( - persisted_checkpoint.priority > self._top_persisted_checkpoints[0].priority - ): - # Keep top num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - worst_checkpoint = heapq.heappushpop( - self._top_persisted_checkpoints, persisted_checkpoint - ) - worst_checkpoint_path = worst_checkpoint.path - remove_from_disk(worst_checkpoint_path) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint_path}.") - else: - # If the latest checkpoint has the same or lower priority, skip it. - logger.debug( - f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." + # Always update the latest memory checkpoint + self._replace_latest_memory_checkpoint(tracked_checkpoint) + + # Only process further if we consider keeping this checkpoint on disk + if self._checkpoint_strategy.num_to_keep != 0: + not_yet_persisted_checkpoint = ( + _NotYetPersistedCheckpoint.from_tracked_checkpoint(tracked_checkpoint) ) + self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - # Update single best checkpoint. - if ( - self._best_persisted_checkpoint is None - or persisted_checkpoint.priority > self._best_persisted_checkpoint.priority - ): - # If the latest checkpoint has the same or lower priority, skip it. - self._best_persisted_checkpoint = persisted_checkpoint + def _get_next_checkpoint_path(self) -> Optional[Path]: + """Path to the next checkpoint to persist.""" + checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + return self.latest_checkpoint_dir.joinpath(checkpoint_file) + # Train-specific attributes @property def latest_checkpoint_dir(self) -> Optional[Path]: """Path to the latest checkpoint directory.""" @@ -229,7 +182,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return self._best_persisted_checkpoint.path + return Path(self._best_persisted_checkpoint.checkpoint_dir_or_data) else: return None @@ -263,8 +216,10 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. 
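`_get_next_checkpoint_path` gives every checkpoint a deterministic, monotonically numbered location under the run directory. The exact filename format is produced by `construct_checkpoint_file_name`, which this hunk does not show, so the zero-padded scheme below is only an illustrative assumption:

from pathlib import Path

def next_checkpoint_path(run_dir: Path, checkpoint_id: int) -> Path:
    # Zero-padding keeps lexicographic and numeric order in sync,
    # which makes "latest checkpoint" easy to find with a glob.
    return run_dir / "checkpoints" / f"checkpoint_{checkpoint_id:06d}"

print(next_checkpoint_path(Path("/tmp/run"), 7))
# /tmp/run/checkpoints/checkpoint_000007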
checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) + def _decide_what_to_do_with_checkpoint( + self, checkpoint: _NotYetPersistedCheckpoint + ): + self.add_tune_checkpoint_id(checkpoint.checkpoint_dir_or_data) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) @@ -273,6 +228,7 @@ def write_checkpoint(self, checkpoint: Dict): file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) with file_path.open("wb") as f: cloudpickle.dump(checkpoint, f) + checkpoint._committed = True def construct_checkpoint_file_name(checkpoint_id: int) -> str: diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index fc42371651c7..3345d7b90d4d 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -1,13 +1,15 @@ # coding: utf-8 -import gc -import heapq import logging from typing import Any, Callable, Optional from ray.tune.result import NODE_IP -from ray.tune.utils.util import flatten_dict -from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, MIN, MAX, \ - CheckpointManager as CommonCheckpointManager +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointStrategy, + MIN, + MAX, + CheckpointManager as CommonCheckpointManager, + _TrackedCheckpoint, +) logger = logging.getLogger(__name__) @@ -77,12 +79,31 @@ def __repr__(self): class CheckpointManager(CommonCheckpointManager): + """Initializes a new CheckpointManager. + + `newest_persistent_checkpoint` and `newest_memory_checkpoint` are + initialized to Checkpoint objects with values of None. + + Args: + keep_checkpoints_num: Keep at least this many checkpoints. + checkpoint_score_attr: Attribute to use to determine which + checkpoints to keep. + delete_fn: Function that deletes checkpoints. Must be + idempotent. + """ + def __init__( self, keep_checkpoints_num: int, checkpoint_score_attr: str, delete_fn: Callable[[str], None], ): + if keep_checkpoints_num == 0: + raise RuntimeError( + "If checkpointing is enabled, Ray Tune requires `keep_checkpoints_num` " + "to be None or a number greater than 0" + ) + checkpoint_score_desc = checkpoint_score_attr.startswith("min-") if checkpoint_score_desc: checkpoint_score_attr = checkpoint_score_attr[4:] @@ -92,148 +113,72 @@ def __init__( checkpoint_strategy = CheckpointStrategy( num_to_keep=keep_checkpoints_num, checkpoint_score_attribute=checkpoint_score_attr, - checkpoint_score_order=MIN if checkpoint_score_desc else MAX + checkpoint_score_order=MIN if checkpoint_score_desc else MAX, ) - def - + self._delete_fn = delete_fn -class CheckpointManagerLegacy: - """Manages checkpoints on the driver for a trial.""" + super().__init__(checkpoint_strategy=checkpoint_strategy) - def __init__( - self, - keep_checkpoints_num: int, - checkpoint_score_attr: str, - delete_fn: Callable[[str], None], - ): - """Initializes a new CheckpointManager. 
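Tune's legacy `checkpoint_score_attr` encodes the sort order in a `min-` prefix, and the adapter in this patch translates that into the new strategy object. A standalone sketch of the translation, with a plain dataclass standing in for `CheckpointStrategy`:

from dataclasses import dataclass
from typing import Optional

MIN, MAX = "min", "max"

@dataclass
class Strategy:
    num_to_keep: Optional[int]
    attribute: str
    order: str

def strategy_from_legacy(keep_checkpoints_num, checkpoint_score_attr) -> Strategy:
    if keep_checkpoints_num == 0:
        raise RuntimeError("keep_checkpoints_num must be None or > 0")
    if checkpoint_score_attr.startswith("min-"):
        return Strategy(keep_checkpoints_num, checkpoint_score_attr[4:], MIN)
    return Strategy(keep_checkpoints_num, checkpoint_score_attr, MAX)

print(strategy_from_legacy(3, "min-validation_loss"))
# Strategy(num_to_keep=3, attribute='validation_loss', order='min')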
+ def on_checkpoint(self, checkpoint: _TrackedCheckpoint): + if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: + self._replace_latest_memory_checkpoint(checkpoint) + else: + assert checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT + assert ( + self._checkpoint_strategy.num_to_keep is None + or self._checkpoint_strategy.num_to_keep > 0 + ) + self._decide_what_to_do_with_checkpoint(checkpoint) - `newest_persistent_checkpoint` and `newest_memory_checkpoint` are - initialized to Checkpoint objects with values of None. + def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + assert persisted_checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT + # Ray Tune always keeps track of the latest persisted checkpoint + self._replace_latest_persisted_checkpoint( + persisted_checkpoint=persisted_checkpoint + ) + logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") - Args: - keep_checkpoints_num: Keep at least this many checkpoints. - checkpoint_score_attr: Attribute to use to determine which - checkpoints to keep. - delete_fn: Function that deletes checkpoints. Must be - idempotent. - """ - self.keep_checkpoints_num = keep_checkpoints_num or float("inf") - assert ( - self.keep_checkpoints_num > 0 - ), "keep_checkpoints_num must be greater than 0." - self._checkpoint_score_desc = checkpoint_score_attr.startswith("min-") - if self._checkpoint_score_desc: - self._checkpoint_score_attr = checkpoint_score_attr[4:] + def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + if persisted_checkpoint == self._latest_persisted_checkpoint: + self._checkpoints_to_clean_up.add(persisted_checkpoint) else: - self._checkpoint_score_attr = checkpoint_score_attr + persisted_checkpoint.delete() - self.delete = delete_fn - self.newest_persistent_checkpoint = _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, None + # Tune-specific properties + + @property + def newest_persistent_checkpoint(self): + return self._latest_persisted_checkpoint or _TrackedCheckpoint( + checkpoint_dir_or_data=None, + checkpoint_id=0, + storage_mode=_TrackedCheckpoint.PERSISTENT, ) - self._newest_memory_checkpoint = _TuneCheckpoint(_TuneCheckpoint.MEMORY, None) - self._best_checkpoints = [] - self._membership = set() - self._cur_order = 0 @property def newest_checkpoint(self): """Returns the newest checkpoint (based on training iteration).""" newest_checkpoint = max( [self.newest_persistent_checkpoint, self.newest_memory_checkpoint], - key=lambda c: c.order, + key=lambda c: c.checkpoint_id, ) return newest_checkpoint @property def newest_memory_checkpoint(self): - return self._newest_memory_checkpoint - - def replace_newest_memory_checkpoint(self, new_checkpoint): - # Forcibly remove the memory checkpoint - del self._newest_memory_checkpoint - # Apparently avoids memory leaks on k8s/k3s/pods - gc.collect() - self._newest_memory_checkpoint = new_checkpoint - - def on_checkpoint(self, checkpoint: _TuneCheckpoint): - """Starts tracking checkpoint metadata on checkpoint. - - Checkpoints get assigned with an `order` as they come in. - The order is monotonically increasing. - - Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes - previous checkpoint as long as it isn't one of the best ones. Also - deletes the worst checkpoint if at capacity. - - Args: - checkpoint: Trial state checkpoint. 
- """ - self._cur_order += 1 - checkpoint.order = self._cur_order - - if checkpoint.storage == _TuneCheckpoint.MEMORY: - self.replace_newest_memory_checkpoint(checkpoint) - return - - old_checkpoint = self.newest_persistent_checkpoint - - if old_checkpoint.value == checkpoint.value: - # Overwrite the order of the checkpoint. - old_checkpoint.order = checkpoint.order - return - - self.newest_persistent_checkpoint = checkpoint - - # Remove the old checkpoint if it isn't one of the best ones. - if old_checkpoint.value and old_checkpoint not in self._membership: - self.delete(old_checkpoint) - - try: - queue_item = QueueItem(self._priority(checkpoint), checkpoint) - except KeyError: - logger.error( - "Result dict has no key: {}. " - "checkpoint_score_attr must be set to a key in the " - "result dict.".format(self._checkpoint_score_attr) - ) - return - - if len(self._best_checkpoints) < self.keep_checkpoints_num: - heapq.heappush(self._best_checkpoints, queue_item) - self._membership.add(checkpoint) - elif queue_item.priority >= self._best_checkpoints[0].priority: - worst = heapq.heappushpop(self._best_checkpoints, queue_item).value - self._membership.add(checkpoint) - if worst in self._membership: - self._membership.remove(worst) - # Don't delete the newest checkpoint. It will be deleted on the - # next on_checkpoint() call since it isn't in self._membership. - if worst.value != checkpoint.value: - self.delete(worst) + return self._latest_memory_checkpoint def best_checkpoints(self): """Returns best PERSISTENT checkpoints, sorted by score.""" - checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority) + checkpoints = sorted(self._top_persisted_checkpoints, key=lambda c: c.priority) return [queue_item.value for queue_item in checkpoints] - def _priority(self, checkpoint): - result = flatten_dict(checkpoint.result) - priority = result[self._checkpoint_score_attr] - return -priority if self._checkpoint_score_desc else priority - def __getstate__(self): - state = self.__dict__.copy() - # Avoid serializing the memory checkpoint. - state["_newest_memory_checkpoint"] = _TuneCheckpoint( - _TuneCheckpoint.MEMORY, None - ) - # Avoid serializing lambda since it may capture cyclical dependencies. 
- state.pop("delete") + state = super().__getstate__() + # Avoid serializing delete fn as it may contain cyclical dependencies + state.pop("_delete_fn", None) return state def __setstate__(self, state): - self.__dict__.update(state) - self.delete = None + state["_delete_fn"] = None + super().__setstate__(state) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index bbbd14820f59..74a190594e8d 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -323,7 +323,7 @@ def testPauseResumeCheckpointCount(self): trial = Trial("__fake", keep_checkpoints_num=2) trial.init_logdir() - trial.checkpoint_manager.delete = lambda cp: shutil.rmtree(cp.value) + trial.checkpoint_manager._delete_fn = lambda cp: shutil.rmtree(cp.value) def write_checkpoint(trial: Trial, index: int): checkpoint_dir = TrainableUtil.make_checkpoint_dir( @@ -377,7 +377,7 @@ def get_checkpoint_dirs(trial: Trial): runner.resume() trial = runner.get_trials()[0] - trial.checkpoint_manager.delete = lambda cp: shutil.rmtree(cp.value) + trial.checkpoint_manager._delete_fn = lambda cp: shutil.rmtree(cp.value) # Write fourth checkpoint result = write_checkpoint(trial, 4) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f793714262f2..f5f2c2c140aa 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -372,7 +372,7 @@ def __init__( self.checkpoint_manager = CheckpointManager( keep_checkpoints_num, checkpoint_score_attr, - CheckpointDeleter(self._trainable_name(), self.runner), + delete_fn=CheckpointDeleter(self._trainable_name(), self.runner), ) # Restoration fields @@ -564,7 +564,7 @@ def set_runner(self, runner): self._default_result_or_future = runner.get_auto_filled_metrics.remote( debug_metrics_only=True ) - self.checkpoint_manager.delete = CheckpointDeleter( + self.checkpoint_manager._delete_fn = CheckpointDeleter( self._trainable_name(), runner ) # No need to invalidate state cache: runner is not stored in json diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 63e6f7afc9d6..7ef658f6dd40 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,14 +1,18 @@ +import gc import heapq import logging import numbers +import os +import shutil from dataclasses import dataclass from pathlib import Path -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, Callable, Tuple import ray from ray.train.constants import TIMESTAMP, TUNE_INSTALLED from ray.tune.result import NODE_IP from ray.util import PublicAPI +from ray.util.ml_utils.util import is_nan if TUNE_INSTALLED: pass @@ -25,13 +29,15 @@ class _TrackedCheckpoint: MEMORY = "memory" PERSISTENT = "persistent" - def __init__(self, - checkpoint_dir_or_data: Union[str, Path, Dict, ray.ObjectRef], - checkpoint_id: int, - storage_mode: str, - result: Optional[Dict] = None, - node_ip: Optional[str] = None, - ): + def __init__( + self, + checkpoint_dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], + checkpoint_id: int, + storage_mode: str, + result: Optional[Dict] = None, + node_ip: Optional[str] = None, + delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, + ): self.checkpoint_dir_or_data = checkpoint_dir_or_data self.checkpoint_id = checkpoint_id self.storage_mode = storage_mode @@ -40,15 +46,34 @@ def __init__(self, self.result = result or {} self.node_ip = node_ip or 
self.result.get(NODE_IP, None) - def commit(self): + self._delete_fn = delete_fn or _default_delete_fn + + def commit(self, path: Optional[Path] = None): pass def delete(self): - pass + self._delete_fn(self) + + +def _default_delete_fn(checkpoint: _TrackedCheckpoint): + if checkpoint.storage_mode != _TrackedCheckpoint.PERSISTENT: + return + + if os.path.isfile(checkpoint.checkpoint_dir_or_data): + os.remove(checkpoint.checkpoint_dir_or_data) + elif os.path.isdir(checkpoint.checkpoint_dir_or_data): + shutil.rmtree(checkpoint.checkpoint_dir_or_data) + else: + logger.warning( + f"Could not delete checkpoint {checkpoint} from disk as it is " + f"neither file not directory. Path: {checkpoint.checkpoint_dir_or_data}." + ) class _HeapCheckpointWrapper: - def __init__(self, priority: numbers.Number, tracked_checkpoint: _TrackedCheckpoint): + def __init__( + self, priority: numbers.Number, tracked_checkpoint: _TrackedCheckpoint + ): self.priority = priority self.tracked_checkpoint = tracked_checkpoint @@ -105,65 +130,65 @@ def __post_init__(self): class CheckpointManager: - def __init__(self, checkpoint_strategy: CheckpointStrategy, latest_checkpoint_id: int = 0): - self._latest_checkpoint_id = latest_checkpoint_id # Todo (krfricke): Review if needed + def __init__( + self, checkpoint_strategy: CheckpointStrategy, latest_checkpoint_id: int = 0 + ): self._checkpoint_strategy = checkpoint_strategy self.latest_checkpoint = None # Incremental unique checkpoint ID of this run. - self._latest_checkpoint_id = 0 + self._latest_checkpoint_id = latest_checkpoint_id # Used for keeping top K checkpoints. self._top_persisted_checkpoints = [] # Best checkpoint altogether. # Used for exposing best_checkpoint_path. - self._best_persisted_checkpoint_wrapped: Optional[_HeapCheckpointWrapper] = None - - self._last_persisted_checkpoint: Optional[_TrackedCheckpoint] = None - - self._last_memory_checkpoint: Optional[_TrackedCheckpoint] = None - - # Do we need this at all? - @property - def _best_persisted_checkpoint(self) -> _TrackedCheckpoint: - return self._best_persisted_checkpoint_wrapped.tracked_checkpoint - - def decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): - if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: - self._last_memory_checkpoint = checkpoint - - second_to_last_persisted_checkpoint = self._last_persisted_checkpoint - self._last_persisted_checkpoint = checkpoint - - num_to_keep = self._checkpoint_strategy.num_to_keep - - if num_to_keep == 0: - if second_to_last_persisted_checkpoint: - second_to_last_persisted_checkpoint.delete() - - # Checkpoints should not be persisted to disk. 
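The default delete function has to cope with both single-file and directory checkpoints, and with paths that no longer exist. A small standalone version of that branching, with logging swapped for a print to keep the sketch dependency-free:

import os
import shutil
import tempfile
from pathlib import Path

def delete_path(path) -> None:
    if os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        # Neither a file nor a directory: warn instead of failing hard.
        print(f"Could not delete {path}: not a file or directory")

d = Path(tempfile.mkdtemp()) / "checkpoint_000001"
d.mkdir()
(d / "model.pkl").write_bytes(b"weights")
delete_path(d)
assert not d.exists()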
- return - + self._best_persisted_checkpoint: Optional[_TrackedCheckpoint] = None + self._latest_persisted_checkpoint: Optional[_TrackedCheckpoint] = None + self._latest_memory_checkpoint: Optional[_TrackedCheckpoint] = None + + # Checkpoints that are not immediately removed + self._checkpoints_to_clean_up = set() + + def _replace_latest_memory_checkpoint(self, memory_checkpoint: _TrackedCheckpoint): + assert memory_checkpoint.storage_mode == _TrackedCheckpoint.MEMORY + self._latest_memory_checkpoint = memory_checkpoint + # Avoid memory leaks on k8s pods + gc.collect() + + def _replace_latest_persisted_checkpoint( + self, persisted_checkpoint: _TrackedCheckpoint + ): + second_to_latest_persisted_checkpoint = self._latest_persisted_checkpoint + self._latest_persisted_checkpoint = persisted_checkpoint + + if self._checkpoint_strategy.num_to_keep == 0: + self._delete_persisted_checkpoint(second_to_latest_persisted_checkpoint) + + def _maybe_replace_best_persisted_checkpoint( + self, persisted_checkpoint: _TrackedCheckpoint + ): + if self._best_persisted_checkpoint is None: + self._best_persisted_checkpoint = persisted_checkpoint + else: + old_score = self._get_checkpoint_score(self._best_persisted_checkpoint) + candidate_score = self._get_checkpoint_score(persisted_checkpoint) + if candidate_score >= old_score: + self._best_persisted_checkpoint = persisted_checkpoint + + def _get_checkpoint_score( + self, checkpoint: _TrackedCheckpoint + ) -> Tuple[bool, numbers.Number, int]: checkpoint_score_attribute = ( self._checkpoint_strategy.checkpoint_score_attribute ) - - if checkpoint_score_attribute not in checkpoint: - raise ValueError( - f"Unable to persist checkpoint for " - f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute}. " - f"Include this attribute in the call to " - f"train.save_checkpoint." - ) - checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order if checkpoint_score_order == MIN: - order_factor = 1. + order_factor = 1.0 else: - order_factor = -1. + order_factor = -1.0 checkpoint_score = order_factor * checkpoint.result[checkpoint_score_attribute] @@ -176,36 +201,69 @@ def decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): f"This attribute must be numerical." ) - wrapped_checkpoint = _HeapCheckpointWrapper(priority=checkpoint_score, tracked_checkpoint=checkpoint) + return ( + not is_nan(checkpoint_score), + checkpoint_score if not is_nan(checkpoint_score) else 0, + checkpoint.checkpoint_id, + ) - if num_to_keep is None: - # Keep all checkpoints. 
- checkpoint.commit() - # Todo: track as latest available persisted checkpoint - pass - elif len(self._top_persisted_checkpoints) < num_to_keep: - # Keep this checkpoint - checkpoint.commit() - heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) + def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): + checkpoint_score = self._get_checkpoint_score(checkpoint) + wrapped_checkpoint = _HeapCheckpointWrapper( + priority=checkpoint_score, tracked_checkpoint=checkpoint + ) + + if self._checkpoint_strategy.num_to_keep is None: + # Keep all checkpoints + checkpoint.commit(path=self._get_next_checkpoint_path()) + self._replace_latest_persisted_checkpoint(checkpoint) elif ( - wrapped_checkpoint.priority > self._top_persisted_checkpoints[0].priority + len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep ): - # Write checkpoint to disk if not yet persisted - checkpoint.commit() - worst_checkpoint = heapq.heappushpop(self._top_persisted_checkpoints, wrapped_checkpoint).tracked_checkpoint - worst_checkpoint.delete() + # Heap is not full yet, so keep this checkpoint + checkpoint.commit(path=self._get_next_checkpoint_path()) + heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) + self._replace_latest_persisted_checkpoint(checkpoint) + elif wrapped_checkpoint.priority > self._top_persisted_checkpoints[0].priority: + # Priority is higher than current worst checkpoint, so replace worst + checkpoint.commit(path=self._get_next_checkpoint_path()) + worst_checkpoint = heapq.heappushpop( + self._top_persisted_checkpoints, wrapped_checkpoint + ).tracked_checkpoint + self._delete_persisted_checkpoint(worst_checkpoint) logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") else: # If the latest checkpoint has the same or lower priority, skip it. - # Todo: fix this - logger.debug( - f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." - ) + self._skip_persisted_checkpoint(checkpoint) - # Update single best checkpoint. - if ( - self._best_persisted_checkpoint is None - or wrapped_checkpoint.priority > self._best_persisted_checkpoint_wrapped.priority - ): - # If the latest checkpoint has the same or lower priority, skip it. - self._best_persisted_checkpoint_wrapped = checkpoint + self._maybe_replace_best_persisted_checkpoint(persisted_checkpoint=checkpoint) + self._cleanup_checkpoints() + + def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + if persisted_checkpoint == self._latest_persisted_checkpoint: + self._checkpoints_to_clean_up.add(persisted_checkpoint) + else: + persisted_checkpoint.delete() + + def _cleanup_checkpoints(self): + for checkpoint in self._checkpoints_to_clean_up: + self._delete_persisted_checkpoint(persisted_checkpoint=checkpoint) + + def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") + + def _get_next_checkpoint_path(self) -> Optional[Path]: + return None + + def __getstate__(self): + state = self.__dict__.copy() + # Avoid serializing the memory checkpoint. 
+ state["_newest_memory_checkpoint"] = _TrackedCheckpoint( + checkpoint_dir_or_data=None, + checkpoint_id=0, + storage_mode=_TrackedCheckpoint.MEMORY, + ) + return state + + def __setstate__(self, state): + self.__dict__.update(state) From 50525c8379807ca96a98e29f31a1dc65de3f68f4 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 3 May 2022 14:03:05 +0100 Subject: [PATCH 05/81] Remove `_TuneCheckpoint` usages --- python/ray/tune/callback.py | 4 +- python/ray/tune/checkpoint_manager.py | 69 +------------ python/ray/tune/ray_trial_executor.py | 23 +++-- python/ray/tune/schedulers/pbt.py | 7 +- python/ray/tune/syncer.py | 21 ++-- .../ray/tune/tests/test_checkpoint_manager.py | 96 ++++++++++++++----- .../ray/tune/tests/test_ray_trial_executor.py | 12 ++- python/ray/tune/tests/test_trial_runner_2.py | 8 +- .../tune/tests/test_trial_runner_callbacks.py | 8 +- python/ray/tune/tests/test_trial_scheduler.py | 19 +++- .../tune/tests/test_trial_scheduler_pbt.py | 9 +- .../test_trial_scheduler_resource_changing.py | 9 +- python/ray/tune/trial.py | 25 +++-- python/ray/tune/trial_executor.py | 9 +- python/ray/tune/trial_runner.py | 9 +- .../ray/util/ml_utils/checkpoint_manager.py | 6 +- 16 files changed, 177 insertions(+), 157 deletions(-) diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index 7b62546f4da0..7b2989c03e99 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -2,8 +2,8 @@ from abc import ABCMeta import warnings -from ray.tune.checkpoint_manager import _TuneCheckpoint from ray.util.annotations import PublicAPI +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint if TYPE_CHECKING: from ray.tune.trial import Trial @@ -245,7 +245,7 @@ def on_checkpoint( iteration: int, trials: List["Trial"], trial: "Trial", - checkpoint: _TuneCheckpoint, + checkpoint: _TrackedCheckpoint, **info, ): """Called after a trial saved a checkpoint with Tune. diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 3345d7b90d4d..7864826efcb3 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -1,8 +1,7 @@ # coding: utf-8 import logging -from typing import Any, Callable, Optional +from typing import Callable -from ray.tune.result import NODE_IP from ray.util.ml_utils.checkpoint_manager import ( CheckpointStrategy, MIN, @@ -14,70 +13,6 @@ logger = logging.getLogger(__name__) -class _TuneCheckpoint: - """Describes a checkpoint of trial state. - - Checkpoint may be saved in different storage. - - Attributes: - storage: Storage type. - value: If storage==MEMORY, it is a Python object. - If storage==PERSISTENT, it is a path to persistent storage, - or a future that will be resolved to such a path. - """ - - MEMORY = "memory" - PERSISTENT = "persistent" - - def __init__( - self, - storage: str, - value: Any, - result: Optional[dict] = None, - node_ip: Optional[str] = None, - ): - self.storage = storage - self.value = value - self.result = result or {} - self.node_ip = node_ip or self.result.get(NODE_IP, None) - # The logical order of checkpoints (both in memory and persistent) - # The more recent checkpoints have larger order. - # The most recent checkpoint is used to restore the trial. - self.order = 0 - - @staticmethod - def from_object(value=None): - """Creates a checkpoint from a Python object.""" - return _TuneCheckpoint(_TuneCheckpoint.MEMORY, value) - - @property - def is_ready(self): - """Returns whether the checkpoint is ready to be used for restoration. 
- - A PERSISTENT checkpoint is considered ready once its value is resolved - to an actual path. MEMORY checkpoints are always considered ready since - they are transient. - """ - if self.storage == _TuneCheckpoint.PERSISTENT: - return isinstance(self.value, str) - return self.storage == _TuneCheckpoint.MEMORY - - def __repr__(self): - return f"Checkpoint({self.storage}, {self.value})" - - -class QueueItem: - def __init__(self, priority, value): - self.priority = priority - self.value = value - - def __lt__(self, other): - return self.priority < other.priority - - def __repr__(self): - return f"QueueItem({repr(self.value)})" - - class CheckpointManager(CommonCheckpointManager): """Initializes a new CheckpointManager. @@ -171,7 +106,7 @@ def newest_memory_checkpoint(self): def best_checkpoints(self): """Returns best PERSISTENT checkpoints, sorted by score.""" checkpoints = sorted(self._top_persisted_checkpoints, key=lambda c: c.priority) - return [queue_item.value for queue_item in checkpoints] + return [wrapped.tracked_checkpoint for wrapped in checkpoints] def __getstate__(self): state = super().__getstate__() diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 38d6dea1328e..e66b778a7691 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -32,12 +32,13 @@ from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE from ray.tune.utils.placement_groups import PlacementGroupManager, get_tune_pg_prefix from ray.tune.utils.trainable import TrainableUtil -from ray.tune.trial import Trial, _TuneCheckpoint, Location, TrialInfo +from ray.tune.trial import Trial, Location, TrialInfo from ray.tune.trial_executor import TrialExecutor from ray.tune.utils import warn_if_slow from ray.tune.utils.resource_updater import ResourceUpdater from ray.util import log_once from ray.util.annotations import DeveloperAPI +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint from ray.util.placement_group import remove_placement_group, PlacementGroup logger = logging.getLogger(__name__) @@ -671,9 +672,9 @@ def force_reconcilation_on_next_step_end(self) -> None: def save( self, trial: Trial, - storage: str = _TuneCheckpoint.PERSISTENT, + storage: str = _TrackedCheckpoint.PERSISTENT, result: Optional[Dict] = None, - ) -> _TuneCheckpoint: + ) -> _TrackedCheckpoint: """Saves the trial's state to a checkpoint asynchronously. Args: @@ -689,13 +690,17 @@ def save( logger.debug(f"saving trial {trial}") result = result or trial.last_result with self._change_working_directory(trial): - if storage == _TuneCheckpoint.MEMORY: + if storage == _TrackedCheckpoint.MEMORY: value = trial.runner.save_to_object.remote() - checkpoint = _TuneCheckpoint(storage, value, result) + checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data=value, storage_mode=storage, result=result + ) trial.on_checkpoint(checkpoint) else: value = trial.runner.save.remote() - checkpoint = _TuneCheckpoint(storage, value, result) + checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data=value, storage_mode=storage, result=result + ) trial.saving_to = checkpoint self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial) return checkpoint @@ -712,15 +717,15 @@ def restore(self, trial: Trial) -> None: ineligible for restoration, given the Tune input arguments. 
""" checkpoint = trial.checkpoint - if checkpoint.value is None: + if checkpoint.checkpoint_dir_or_data is None: return if trial.runner is None: raise RuntimeError( "Trial {}: Unable to restore - no runner found.".format(trial) ) - value = checkpoint.value + value = checkpoint.checkpoint_dir_or_data node_ip = checkpoint.node_ip - if checkpoint.storage == _TuneCheckpoint.MEMORY: + if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: logger.debug("Trial %s: Attempting restore from object", trial) # Note that we don't store the remote since in-memory checkpoints # don't guarantee fault tolerance and don't need to be waited on. diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index a43276926eb5..e341a2b8ea31 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -16,8 +16,9 @@ from ray.tune.sample import Domain, Function from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars -from ray.tune.trial import Trial, _TuneCheckpoint +from ray.tune.trial import Trial from ray.util.debug import log_once +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint logger = logging.getLogger(__name__) @@ -528,7 +529,7 @@ def _checkpoint_or_exploit( state.last_checkpoint = trial.checkpoint else: state.last_checkpoint = trial_executor.save( - trial, _TuneCheckpoint.MEMORY, result=state.last_result + trial, _TrackedCheckpoint.MEMORY, result=state.last_result ) self._num_checkpoints += 1 else: @@ -872,7 +873,7 @@ def on_trial_result( ) checkpoint = trial_runner.trial_executor.save( - trial, _TuneCheckpoint.MEMORY, result=result + trial, _TrackedCheckpoint.MEMORY, result=result ) new_tag = make_experiment_tag(self.experiment_tag, new_config, new_config) diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index ab6d05663bb0..a24b63700d45 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -24,7 +24,6 @@ from ray.ml.utils.remote_storage import get_fs_and_path, fs_hint from ray.tune import TuneError from ray.tune.callback import Callback -from ray.tune.checkpoint_manager import _TuneCheckpoint from ray.tune.result import NODE_IP from ray.util import get_node_ip_address from ray.util.debug import log_once @@ -38,6 +37,7 @@ RemoteTaskClient, ) from ray.util.annotations import PublicAPI +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint if TYPE_CHECKING: from ray.tune.trial import Trial @@ -522,11 +522,8 @@ def _create_trial_syncer(self, trial: "Trial"): trial.logdir, remote_dir=trial.logdir, sync_function=self._sync_function ) - def _remove_trial_syncer(self, trial: "Trial"): - self._syncers.pop(trial, None) - - def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TuneCheckpoint): - if checkpoint.storage == _TuneCheckpoint.MEMORY: + def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint): + if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: return trial_syncer = self._get_trial_syncer(trial) @@ -565,7 +562,7 @@ def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TuneCheckpoint): # shouldn't track it with the checkpoint_manager. raise e if not trial.uses_cloud_checkpointing: - if not os.path.exists(checkpoint.value): + if not os.path.exists(checkpoint.checkpoint_dir_or_data): raise TuneError( "Trial {}: Checkpoint path {} not " "found after successful sync down. 
" @@ -575,7 +572,9 @@ def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TuneCheckpoint): "You'll need to use cloud-checkpointing " "if that's the case, see instructions " "here: {} .".format( - trial, checkpoint.value, CLOUD_CHECKPOINTING_URL + trial, + checkpoint.checkpoint_dir_or_data, + CLOUD_CHECKPOINTING_URL, ) ) @@ -605,17 +604,15 @@ def on_trial_complete( else: trainable_ip = ray.get(trial.runner.get_current_ip.remote()) trial_syncer.set_worker_ip(trainable_ip) - # Always sync down when trial completed - trial_syncer.sync_down() + trial_syncer.sync_down_if_needed() trial_syncer.close() - self._remove_trial_syncer(trial) def on_checkpoint( self, iteration: int, trials: List["Trial"], trial: "Trial", - checkpoint: _TuneCheckpoint, + checkpoint: _TrackedCheckpoint, **info, ): self._sync_trial_checkpoint(trial, checkpoint) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index b4a9267b63d9..824d0ec71471 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -7,7 +7,8 @@ from unittest.mock import patch from ray.tune.result import TRAINING_ITERATION -from ray.tune.checkpoint_manager import _TuneCheckpoint, CheckpointManager, logger +from ray.tune.checkpoint_manager import CheckpointManager, logger +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint class CheckpointManagerTest(unittest.TestCase): @@ -20,12 +21,16 @@ def checkpoint_manager(self, keep_checkpoints_num): def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - memory_checkpoint = _TuneCheckpoint( - _TuneCheckpoint.MEMORY, {0}, self.mock_result(0, 0) + memory_checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data={0}, + storage_mode=_TrackedCheckpoint.MEMORY, + result=self.mock_result(0, 0), ) checkpoint_manager.on_checkpoint(memory_checkpoint) - persistent_checkpoint = _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1, 1) + persistent_checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data={1}, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(1, 1), ) checkpoint_manager.on_checkpoint(persistent_checkpoint) self.assertEqual( @@ -40,7 +45,11 @@ def testOnCheckpointOrdered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i)) + _TrackedCheckpoint( + checkpoint_dir_or_data={i}, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(i, i), + ) for i in range(3) ] @@ -66,7 +75,11 @@ def testOnCheckpointUnordered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i)) + _TrackedCheckpoint( + checkpoint_dir_or_data={i}, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(i, i), + ) for i in range(3, -1, -1) ] @@ -90,7 +103,11 @@ def testBestCheckpoints(self): """ keep_checkpoints_num = 4 checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i, i)) + _TrackedCheckpoint( + checkpoint_dir_or_data=i, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(i, i), + ) for i in range(8) ] @@ -111,11 +128,19 @@ def testBestCheckpointsWithNan(self): """ keep_checkpoints_num = 2 checkpoints = [ - _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, None, 
self.mock_result(float("nan"), i) + _TrackedCheckpoint( + checkpoint_dir_or_data=None, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(float("nan"), i), ) for i in range(2) - ] + [_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))] + ] + [ + _TrackedCheckpoint( + checkpoint_dir_or_data=3, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(0, 3), + ) + ] for permutation in itertools.permutations(checkpoints): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) @@ -135,8 +160,10 @@ def testBestCheckpointsOnlyNan(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, i, self.mock_result(float("nan"), i) + _TrackedCheckpoint( + checkpoint_dir_or_data=i, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(float("nan"), i), ) for i in range(4) ] @@ -157,7 +184,12 @@ def testOnCheckpointUnavailableAttribute(self): """ checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - no_attr_checkpoint = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 0, {}) + no_attr_checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data=0, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result={}, + ) + with patch.object(logger, "error") as log_error_mock: checkpoint_manager.on_checkpoint(no_attr_checkpoint) log_error_mock.assert_called_once() @@ -168,8 +200,16 @@ def testOnCheckpointUnavailableAttribute(self): def testOnMemoryCheckpoint(self): checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)), - _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)), + _TrackedCheckpoint( + checkpoint_dir_or_data=0, + storage_mode=_TrackedCheckpoint.MEMORY, + result=self.mock_result(0, 0), + ), + _TrackedCheckpoint( + checkpoint_dir_or_data=0, + storage_mode=_TrackedCheckpoint.MEMORY, + result=self.mock_result(0, 0), + ), ] checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) checkpoint_manager.on_checkpoint(checkpoints[0]) @@ -192,17 +232,25 @@ def testSameCheckpoint(self): tmpfiles.append(tmpfile) checkpoints = [ - _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5, 5) + _TrackedCheckpoint( + checkpoint_dir_or_data=tmpfiles[0], + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(5, 5), ), - _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10, 10) + _TrackedCheckpoint( + checkpoint_dir_or_data=tmpfiles[1], + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(10, 10), ), - _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0, 0) + _TrackedCheckpoint( + checkpoint_dir_or_data=tmpfiles[2], + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(0, 0), ), - _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20, 20) + _TrackedCheckpoint( + checkpoint_dir_or_data=tmpfiles[1], + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=self.mock_result(20, 20), ), ] for checkpoint in checkpoints: diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index f9e78b125c3e..3e69793cbca9 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -18,12 +18,14 @@ from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID from 
ray.tune.suggest import BasicVariantGenerator -from ray.tune.trial import Trial, _TuneCheckpoint +from ray.tune.trial import Trial from ray.tune.resources import Resources from ray.cluster_utils import Cluster from ray.tune.utils.placement_groups import PlacementGroupFactory from unittest.mock import patch +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint + class TrialExecutorInsufficientResourcesTest(unittest.TestCase): def setUp(self): @@ -118,9 +120,9 @@ def _simulate_getting_result(self, trial): trial.update_last_result(training_result) def _simulate_saving(self, trial): - checkpoint = self.trial_executor.save(trial, _TuneCheckpoint.PERSISTENT) + checkpoint = self.trial_executor.save(trial, _TrackedCheckpoint.PERSISTENT) self.assertEqual(checkpoint, trial.saving_to) - self.assertEqual(trial.checkpoint.value, None) + self.assertEqual(trial.checkpoint.checkpoint_dir_or_data, None) event = self.trial_executor.get_next_executor_event( live_trials={trial}, next_trial_exists=False ) @@ -187,7 +189,7 @@ def testSavePauseResumeErrorRestore(self): # Pause self.trial_executor.pause_trial(trial) self.assertEqual(Trial.PAUSED, trial.status) - self.assertEqual(trial.checkpoint.storage, _TuneCheckpoint.MEMORY) + self.assertEqual(trial.checkpoint.storage_mode, _TrackedCheckpoint.MEMORY) # Resume self._simulate_starting_trial(trial) @@ -374,7 +376,7 @@ def generate_trials(spec, name): def process_trial_save(self, trial, checkpoint_value): """Simulates trial runner save.""" checkpoint = trial.saving_to - checkpoint.value = checkpoint_value + checkpoint.checkpoint_dir_or_data = checkpoint_value trial.on_checkpoint(checkpoint) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 74a190594e8d..8832274976bc 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -9,7 +9,6 @@ from ray.rllib import _register_all from ray.tune import TuneError -from ray.tune.checkpoint_manager import _TuneCheckpoint from ray.tune.schedulers import FIFOScheduler from ray.tune.result import DONE from ray.tune.registry import _global_registry, TRAINABLE_CLASS @@ -19,6 +18,7 @@ from ray.tune.suggest import BasicVariantGenerator from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver from ray.tune.utils.trainable import TrainableUtil +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint def create_mock_components(): @@ -333,8 +333,10 @@ def write_checkpoint(trial: Trial, index: int): with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f: json.dump(result, f) - tune_cp = _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, checkpoint_dir, result + tune_cp = _TrackedCheckpoint( + checkpoint_dir_or_data=checkpoint_dir, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=result, ) trial.saving_to = tune_cp trial.on_checkpoint(tune_cp) diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index cae78aeb0476..c9414f5cbeec 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -10,7 +10,6 @@ import ray from ray import tune from ray.rllib import _register_all -from ray.tune.checkpoint_manager import _TuneCheckpoint from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback from ray.tune.ray_trial_executor import ( ExecutorEvent, @@ -26,6 +25,7 @@ from ray.tune import Callback from ray.tune.utils.callback import 
create_default_callbacks from ray.tune.experiment import Experiment +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint class TestCallback(Callback): @@ -150,8 +150,10 @@ def testCallbackSteps(self): self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "two") # Just a placeholder object ref for cp.value. - cp = _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, value=ray.put(1), result={TRAINING_ITERATION: 0} + cp = _TrackedCheckpoint( + checkpoint_dir_or_data=ray.put(1), + storage_mode=_TrackedCheckpoint.PERSISTENT, + result={TRAINING_ITERATION: 0}, ) trials[0].saving_to = cp diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index df3b497626d4..ddb0c9781929 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -28,11 +28,12 @@ from ray.tune.schedulers.pbt import explore, PopulationBasedTrainingReplay from ray.tune.suggest._mock import _MockSearcher from ray.tune.suggest.suggestion import ConcurrencyLimiter -from ray.tune.trial import Trial, _TuneCheckpoint +from ray.tune.trial import Trial from ray.tune.trial_executor import TrialExecutor from ray.tune.resources import Resources from ray.rllib import _register_all +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint _register_all() @@ -248,8 +249,12 @@ def stop_trial(self, trial, error=False, error_msg=None): def restore(self, trial, checkpoint=None, block=False): pass - def save(self, trial, type=_TuneCheckpoint.PERSISTENT, result=None): - return _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, trial.trainable_name, result) + def save(self, trial, type=_TrackedCheckpoint.PERSISTENT, result=None): + return _TrackedCheckpoint( + checkpoint_dir_or_data=trial.trainable_name, + storage_mode=_TrackedCheckpoint.PERSISTENT, + result=result, + ) def reset_trial(self, trial, new_config, new_experiment_tag): return False @@ -307,7 +312,7 @@ def get_live_trials(self): return {t for t in self.trials if t.status != Trial.TERMINATED} def _pause_trial(self, trial): - self.trial_executor.save(trial, _TuneCheckpoint.MEMORY, None) + self.trial_executor.save(trial, _TrackedCheckpoint.MEMORY, None) trial.status = Trial.PAUSED def _launch_trial(self, trial): @@ -842,7 +847,11 @@ def on_checkpoint(self, checkpoint): @property def checkpoint(self): - return _TuneCheckpoint(_TuneCheckpoint.MEMORY, self.trainable_name, None) + return _TrackedCheckpoint( + checkpoint_dir_or_data=self.trainable_name, + storage_mode=_TrackedCheckpoint.MEMORY, + result=None, + ) class PopulationBasedTestingSuite(unittest.TestCase): diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index e30e4aa13f88..e53c8150e310 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -11,7 +11,7 @@ import ray from ray import tune from ray.tune import Trainable -from ray.tune.trial import Trial, _TuneCheckpoint +from ray.tune.trial import Trial from ray.tune.trial_runner import TrialRunner from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import PopulationBasedTraining @@ -20,6 +20,7 @@ # Import psutil after ray so the packaged version is used. 
import psutil +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint MB = 1024 ** 2 @@ -439,7 +440,11 @@ def testBurnInPeriod(self): class MockTrial(Trial): @property def checkpoint(self): - return _TuneCheckpoint(_TuneCheckpoint.MEMORY, "None", {}) + return _TrackedCheckpoint( + checkpoint_dir_or_data="None", + storage_mode=_TrackedCheckpoint.MEMORY, + result={}, + ) @property def status(self): diff --git a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py index 46fe88c993a4..ed56bb3ab815 100644 --- a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py +++ b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py @@ -2,12 +2,13 @@ from ray.tune import PlacementGroupFactory from ray.tune.schedulers.trial_scheduler import TrialScheduler -from ray.tune.trial import Trial, _TuneCheckpoint +from ray.tune.trial import Trial from ray.tune.schedulers.resource_changing_scheduler import ( ResourceChangingScheduler, DistributeResources, DistributeResourcesToTopJob, ) +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint class MockResourceUpdater: @@ -48,7 +49,11 @@ def get_trials(self): class MockTrial(Trial): @property def checkpoint(self): - return _TuneCheckpoint(_TuneCheckpoint.MEMORY, "None", {}) + return _TrackedCheckpoint( + checkpoint_dir_or_data="None", + storage_mode=_TrackedCheckpoint.MEMORY, + result={}, + ) class TestUniformResourceAllocation(unittest.TestCase): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f5f2c2c140aa..33ab18665118 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -15,7 +15,7 @@ import ray.cloudpickle as cloudpickle from ray.exceptions import RayActorError, RayTaskError from ray.tune import TuneError -from ray.tune.checkpoint_manager import _TuneCheckpoint, CheckpointManager +from ray.tune.checkpoint_manager import CheckpointManager # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not @@ -40,6 +40,7 @@ from ray.tune.utils import date_str, flatten_dict from ray.util.annotations import DeveloperAPI from ray._private.utils import binary_to_hex, hex_to_binary +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint DEBUG_PRINT_INTERVAL = 5 logger = logging.getLogger(__name__) @@ -99,7 +100,7 @@ def __init__(self, trial_id, runner): self.trial_id = trial_id self.runner = runner - def __call__(self, checkpoint: _TuneCheckpoint): + def __call__(self, checkpoint: _TrackedCheckpoint): """Requests checkpoint deletion asynchronously. 
Args: @@ -108,8 +109,11 @@ def __call__(self, checkpoint: _TuneCheckpoint): if not self.runner: return - if checkpoint.storage == _TuneCheckpoint.PERSISTENT and checkpoint.value: - checkpoint_path = checkpoint.value + if ( + checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT + and checkpoint.checkpoint_dir_or_data + ): + checkpoint_path = checkpoint.checkpoint_dir_or_data logger.debug( "Trial %s: Deleting checkpoint %s", self.trial_id, checkpoint_path @@ -462,8 +466,11 @@ def checkpoint(self): checkpoint = self.checkpoint_manager.newest_persistent_checkpoint else: checkpoint = self.checkpoint_manager.newest_checkpoint - if checkpoint.value is None: - checkpoint = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, self.restore_path) + if checkpoint.checkpoint_dir_or_data is None: + checkpoint = _TrackedCheckpoint( + checkpoint_dir_or_data=self.restore_path, + storage_mode=_TrackedCheckpoint.PERSISTENT, + ) return checkpoint @classmethod @@ -641,14 +648,14 @@ def should_checkpoint(self): ) def has_checkpoint(self): - return self.checkpoint.value is not None + return self.checkpoint.checkpoint_dir_or_data is not None def clear_checkpoint(self): - self.checkpoint.value = None + self.checkpoint.checkpoint_dir_or_data = None self.restoring_from = None self.invalidate_json_state() - def on_checkpoint(self, checkpoint: _TuneCheckpoint): + def on_checkpoint(self, checkpoint: _TrackedCheckpoint): """Hook for handling checkpoints taken by the Trainable. Args: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 5bea2bbb8ad0..e391864a81dc 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -6,7 +6,8 @@ from ray.exceptions import RayTaskError from ray.tune import TuneError from ray.util.annotations import DeveloperAPI -from ray.tune.trial import Trial, _TuneCheckpoint +from ray.tune.trial import Trial +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint logger = logging.getLogger(__name__) @@ -123,7 +124,7 @@ def pause_trial(self, trial: Trial) -> None: """ assert trial.status == Trial.RUNNING, trial.status try: - self.save(trial, _TuneCheckpoint.MEMORY) + self.save(trial, _TrackedCheckpoint.MEMORY) self.stop_trial(trial) self.set_status(trial, Trial.PAUSED) except Exception: @@ -193,9 +194,9 @@ def restore(self, trial: Trial) -> None: def save( self, trial: Trial, - storage: str = _TuneCheckpoint.PERSISTENT, + storage: str = _TrackedCheckpoint.PERSISTENT, result: Optional[Dict] = None, - ) -> _TuneCheckpoint: + ) -> _TrackedCheckpoint: """Saves training state of this trial to a checkpoint. If result is None, this trial's last result will be used. 
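The hunks above apply one mechanical migration: the positional ``_TuneCheckpoint(storage, value, result)`` construction is replaced by keyword construction of ``_TrackedCheckpoint``. A minimal sketch of that pattern as it stands at this point in the series (the path and result values below are illustrative and not taken from the patch):

from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint

# Old style, removed by this series:
#     _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, value=checkpoint_path, result=result)
# New style used throughout the hunks above:
checkpoint = _TrackedCheckpoint(
    checkpoint_dir_or_data="/tmp/illustrative_checkpoint_dir",  # path, dict, or ObjectRef
    storage_mode=_TrackedCheckpoint.PERSISTENT,
    result={"training_iteration": 1},
)
assert checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT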
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 1c199f2d02e7..04b392fbe72d 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -35,13 +35,14 @@ from ray.tune.stopper import NoopStopper, Stopper from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm from ray.tune.syncer import CloudSyncer, get_cloud_syncer, SyncConfig -from ray.tune.trial import _TuneCheckpoint, Trial +from ray.tune.trial import Trial from ray.tune.utils import warn_if_slow, flatten_dict from ray.tune.utils.log import Verbosity, has_verbosity from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder from ray.tune.web_server import TuneServer from ray.util.debug import log_once +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint MAX_DEBUG_TRIALS = 20 @@ -1108,7 +1109,7 @@ def _process_trial_save( logger.debug("Trial %s: Processing trial save.", trial) try: - trial.saving_to.value = checkpoint_value + trial.saving_to.checkpoint_dir_or_data = checkpoint_value self._callbacks.on_checkpoint( iteration=self._iteration, trials=self._trials, @@ -1116,7 +1117,7 @@ def _process_trial_save( checkpoint=trial.saving_to, ) trial.on_checkpoint(trial.saving_to) - if trial.checkpoint.storage != _TuneCheckpoint.MEMORY: + if trial.checkpoint.storage_mode != _TrackedCheckpoint.MEMORY: self.trial_executor.mark_trial_to_checkpoint(trial) except Exception: logger.exception( @@ -1203,7 +1204,7 @@ def _checkpoint_trial_if_needed(self, trial, force=False): if trial.should_checkpoint() or force: # Save trial runtime if possible. if trial.runner: - self.trial_executor.save(trial, storage=_TuneCheckpoint.PERSISTENT) + self.trial_executor.save(trial, storage=_TrackedCheckpoint.PERSISTENT) def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]): """Tries to recover trial. diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 7ef658f6dd40..6373371e6592 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -6,7 +6,7 @@ import shutil from dataclasses import dataclass from pathlib import Path -from typing import Optional, Dict, Union, Callable, Tuple +from typing import Optional, Dict, Union, Callable, Tuple, List import ray from ray.train.constants import TIMESTAMP, TUNE_INSTALLED @@ -32,8 +32,8 @@ class _TrackedCheckpoint: def __init__( self, checkpoint_dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], - checkpoint_id: int, storage_mode: str, + checkpoint_id: int = 0, result: Optional[Dict] = None, node_ip: Optional[str] = None, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, @@ -141,7 +141,7 @@ def __init__( self._latest_checkpoint_id = latest_checkpoint_id # Used for keeping top K checkpoints. - self._top_persisted_checkpoints = [] + self._top_persisted_checkpoints: List[_HeapCheckpointWrapper] = [] # Best checkpoint altogether. # Used for exposing best_checkpoint_path. 
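The ``_top_persisted_checkpoints`` list typed above backs the ``num_to_keep`` retention policy: the manager keeps a bounded min-heap of score-wrapped checkpoints, so the lowest-priority entry sits at the heap root and is the one displaced when a better checkpoint arrives. A standalone sketch of that bookkeeping idea, with illustrative names that are not part of the patch:

import heapq

def keep_best(k, scored_checkpoints):
    # Keep the k highest-scoring items; the worst kept item always sits at heap[0].
    heap = []
    for score, name in scored_checkpoints:
        if len(heap) < k:
            heapq.heappush(heap, (score, name))
        elif score > heap[0][0]:
            # The new item beats the current worst kept item: swap them in one step.
            dropped = heapq.heappushpop(heap, (score, name))
            # A real manager would delete ``dropped`` from disk at this point.
    return sorted(heap, reverse=True)

# Example: keep the 2 best of 4 checkpoints by score.
print(keep_best(2, [(0.1, "cp_0"), (0.9, "cp_1"), (0.5, "cp_2"), (0.7, "cp_3")]))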
From 5cdb3047450f15416051b6530a549c9fe7b67914 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 3 May 2022 15:09:29 +0100 Subject: [PATCH 06/81] test_checkpoint_manager.py is passing --- python/ray/tune/checkpoint_manager.py | 17 ++-- .../ray/tune/tests/test_checkpoint_manager.py | 40 +++++--- .../ray/util/ml_utils/checkpoint_manager.py | 94 ++++++++++++------- 3 files changed, 99 insertions(+), 52 deletions(-) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 7864826efcb3..cb57524225a5 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -56,6 +56,11 @@ def __init__( super().__init__(checkpoint_strategy=checkpoint_strategy) def on_checkpoint(self, checkpoint: _TrackedCheckpoint): + checkpoint.checkpoint_id = ( + checkpoint.checkpoint_id or self._latest_checkpoint_id + ) + self._latest_checkpoint_id += 1 + if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: self._replace_latest_memory_checkpoint(checkpoint) else: @@ -68,17 +73,13 @@ def on_checkpoint(self, checkpoint: _TrackedCheckpoint): def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): assert persisted_checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT - # Ray Tune always keeps track of the latest persisted checkpoint + super()._skip_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) + # Ray Tune always keeps track of the latest persisted checkpoint. + # Note that this checkpoint will be deleted once it is not the + # latest checkpoint anymore self._replace_latest_persisted_checkpoint( persisted_checkpoint=persisted_checkpoint ) - logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") - - def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): - if persisted_checkpoint == self._latest_persisted_checkpoint: - self._checkpoints_to_clean_up.add(persisted_checkpoint) - else: - persisted_checkpoint.delete() # Tune-specific properties diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 824d0ec71471..65392040abf3 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -17,7 +17,11 @@ def mock_result(metric, i): return {"i": metric, TRAINING_ITERATION: i} def checkpoint_manager(self, keep_checkpoints_num): - return CheckpointManager(keep_checkpoints_num, "i", delete_fn=lambda c: None) + return CheckpointManager( + keep_checkpoints_num=keep_checkpoints_num, + checkpoint_score_attr="i", + delete_fn=lambda c: None, + ) def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) @@ -53,7 +57,9 @@ def testOnCheckpointOrdered(self): for i in range(3) ] - with patch.object(checkpoint_manager, "delete") as delete_mock: + with patch.object( + checkpoint_manager, "_delete_persisted_checkpoint" + ) as delete_mock: for j in range(3): checkpoint_manager.on_checkpoint(checkpoints[j]) expected_deletes = 0 if j != 2 else 1 @@ -83,11 +89,17 @@ def testOnCheckpointUnordered(self): for i in range(3, -1, -1) ] - with patch.object(checkpoint_manager, "delete") as delete_mock: + with patch.object( + checkpoint_manager, "_delete_persisted_checkpoint" + ) as delete_mock: for j in range(0, len(checkpoints)): checkpoint_manager.on_checkpoint(checkpoints[j]) expected_deletes = 0 if j != 3 else 1 - self.assertEqual(delete_mock.call_count, expected_deletes) + self.assertEqual( + delete_mock.call_count, + 
expected_deletes, + msg=f"Called {delete_mock.call_count} times", + ) self.assertEqual( checkpoint_manager.newest_persistent_checkpoint, checkpoints[j] ) @@ -120,7 +132,7 @@ def testBestCheckpoints(self): best_checkpoints = checkpoint_manager.best_checkpoints() self.assertEqual(len(best_checkpoints), keep_checkpoints_num) for i in range(len(best_checkpoints)): - self.assertEqual(best_checkpoints[i].value, i + 4) + self.assertEqual(best_checkpoints[i].checkpoint_dir_or_data, i + 4) def testBestCheckpointsWithNan(self): """ @@ -150,8 +162,8 @@ def testBestCheckpointsWithNan(self): best_checkpoints = checkpoint_manager.best_checkpoints() # best_checkpoints is sorted from worst to best self.assertEqual(len(best_checkpoints), keep_checkpoints_num) - self.assertEqual(best_checkpoints[0].value, None) - self.assertEqual(best_checkpoints[1].value, 3) + self.assertEqual(best_checkpoints[0].checkpoint_dir_or_data, None) + self.assertEqual(best_checkpoints[1].checkpoint_dir_or_data, 3) def testBestCheckpointsOnlyNan(self): """ @@ -174,8 +186,8 @@ def testBestCheckpointsOnlyNan(self): best_checkpoints = checkpoint_manager.best_checkpoints() # best_checkpoints is sorted from worst to best self.assertEqual(len(best_checkpoints), keep_checkpoints_num) - self.assertEqual(best_checkpoints[0].value, 2) - self.assertEqual(best_checkpoints[1].value, 3) + self.assertEqual(best_checkpoints[0].checkpoint_dir_or_data, 2) + self.assertEqual(best_checkpoints[1].checkpoint_dir_or_data, 3) def testOnCheckpointUnavailableAttribute(self): """ @@ -190,7 +202,9 @@ def testOnCheckpointUnavailableAttribute(self): result={}, ) - with patch.object(logger, "error") as log_error_mock: + from ray.util.ml_utils.checkpoint_manager import logger as cp_logger + + with patch.object(cp_logger, "error") as log_error_mock: checkpoint_manager.on_checkpoint(no_attr_checkpoint) log_error_mock.assert_called_once() # The newest checkpoint should still be set despite this error. 
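The NaN-related tests above pin down the intended ordering: a checkpoint whose score is NaN must sort below any checkpoint with a real score instead of raising. One simple way to express that as a sort key, shown purely to illustrate the ordering the tests assert (not necessarily the exact priority tuple the manager builds internally):

import math

def sort_key(score):
    # A NaN score must lose against every real number.
    if math.isnan(score):
        return (0, 0.0)   # bucket 0: NaN, always worst
    return (1, score)     # bucket 1: real scores, ordered by value

scores = [float("nan"), 3.0, 0.5]
print(sorted(scores, key=sort_key))  # NaN first (worst), then 0.5, then 3.0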
@@ -221,7 +235,9 @@ def testOnMemoryCheckpoint(self): def testSameCheckpoint(self): checkpoint_manager = CheckpointManager( - 1, "i", delete_fn=lambda c: os.remove(c.value) + keep_checkpoints_num=1, + checkpoint_score_attr="i", + delete_fn=lambda c: os.remove(c.checkpoint_dir_or_data), ) tmpfiles = [] @@ -255,7 +271,7 @@ def testSameCheckpoint(self): ] for checkpoint in checkpoints: checkpoint_manager.on_checkpoint(checkpoint) - self.assertTrue(os.path.exists(checkpoint.value)) + self.assertTrue(os.path.exists(checkpoint.checkpoint_dir_or_data)) for tmpfile in tmpfiles: if os.path.exists(tmpfile): diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 6373371e6592..c054b0e67667 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -6,19 +6,13 @@ import shutil from dataclasses import dataclass from pathlib import Path -from typing import Optional, Dict, Union, Callable, Tuple, List +from typing import Optional, Dict, Union, Callable, Tuple, List, Any import ray -from ray.train.constants import TIMESTAMP, TUNE_INSTALLED -from ray.tune.result import NODE_IP +from ray.tune.result import NODE_IP, TRAINING_ITERATION from ray.util import PublicAPI from ray.util.ml_utils.util import is_nan -if TUNE_INSTALLED: - pass -else: - tune = None - MAX = "max" MIN = "min" @@ -33,7 +27,7 @@ def __init__( self, checkpoint_dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], storage_mode: str, - checkpoint_id: int = 0, + checkpoint_id: Optional[int] = None, result: Optional[Dict] = None, node_ip: Optional[str] = None, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, @@ -54,26 +48,35 @@ def commit(self, path: Optional[Path] = None): def delete(self): self._delete_fn(self) + def __repr__(self): + if self.storage_mode == _TrackedCheckpoint.MEMORY: + return f"<_TrackedCheckpoint storage='MEMORY' result={self.result}>" + + return ( + f"<_TrackedCheckpoint storage='PERSISTENT' " + f"checkpoint_dir_or_data={self.checkpoint_dir_or_data}>" + ) + def _default_delete_fn(checkpoint: _TrackedCheckpoint): if checkpoint.storage_mode != _TrackedCheckpoint.PERSISTENT: return - if os.path.isfile(checkpoint.checkpoint_dir_or_data): - os.remove(checkpoint.checkpoint_dir_or_data) - elif os.path.isdir(checkpoint.checkpoint_dir_or_data): - shutil.rmtree(checkpoint.checkpoint_dir_or_data) - else: - logger.warning( - f"Could not delete checkpoint {checkpoint} from disk as it is " - f"neither file not directory. Path: {checkpoint.checkpoint_dir_or_data}." - ) + if isinstance(checkpoint.checkpoint_dir_or_data, (str, bytes, os.PathLike)): + if os.path.isfile(checkpoint.checkpoint_dir_or_data): + os.remove(checkpoint.checkpoint_dir_or_data) + return + elif os.path.isdir(checkpoint.checkpoint_dir_or_data): + shutil.rmtree(checkpoint.checkpoint_dir_or_data) + return + logger.warning( + f"Could not delete checkpoint {checkpoint} from disk as it is " + f"neither file not directory. Path: {checkpoint.checkpoint_dir_or_data}." 
+ ) class _HeapCheckpointWrapper: - def __init__( - self, priority: numbers.Number, tracked_checkpoint: _TrackedCheckpoint - ): + def __init__(self, priority: Any, tracked_checkpoint: _TrackedCheckpoint): self.priority = priority self.tracked_checkpoint = tracked_checkpoint @@ -113,7 +116,7 @@ class CheckpointStrategy: """ num_to_keep: Optional[int] = None - checkpoint_score_attribute: str = TIMESTAMP + checkpoint_score_attribute: str = TRAINING_ITERATION checkpoint_score_order: str = MAX def __post_init__(self): @@ -165,7 +168,9 @@ def _replace_latest_persisted_checkpoint( self._latest_persisted_checkpoint = persisted_checkpoint if self._checkpoint_strategy.num_to_keep == 0: - self._delete_persisted_checkpoint(second_to_latest_persisted_checkpoint) + self._maybe_delete_persisted_checkpoint( + second_to_latest_persisted_checkpoint + ) def _maybe_replace_best_persisted_checkpoint( self, persisted_checkpoint: _TrackedCheckpoint @@ -184,13 +189,23 @@ def _get_checkpoint_score( checkpoint_score_attribute = ( self._checkpoint_strategy.checkpoint_score_attribute ) + if checkpoint_score_attribute not in checkpoint.result: + logger.error( + f"Result dict has no key: {checkpoint_score_attribute}. " + f"checkpoint_score_attr must be set to a key in the " + f"result dict." + ) + checkpoint_result = float("-inf") + else: + checkpoint_result = checkpoint.result[checkpoint_score_attribute] + checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_order == MIN: + if checkpoint_score_order == MAX: order_factor = 1.0 else: order_factor = -1.0 - checkpoint_score = order_factor * checkpoint.result[checkpoint_score_attribute] + checkpoint_score = order_factor * checkpoint_result if not isinstance(checkpoint_score, numbers.Number): raise ValueError( @@ -224,14 +239,22 @@ def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): checkpoint.commit(path=self._get_next_checkpoint_path()) heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) self._replace_latest_persisted_checkpoint(checkpoint) - elif wrapped_checkpoint.priority > self._top_persisted_checkpoints[0].priority: + elif wrapped_checkpoint.priority >= self._top_persisted_checkpoints[0].priority: # Priority is higher than current worst checkpoint, so replace worst checkpoint.commit(path=self._get_next_checkpoint_path()) worst_checkpoint = heapq.heappushpop( self._top_persisted_checkpoints, wrapped_checkpoint ).tracked_checkpoint - self._delete_persisted_checkpoint(worst_checkpoint) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") + + # Only remove if checkpoint data is different + if ( + worst_checkpoint.checkpoint_dir_or_data + != checkpoint.checkpoint_dir_or_data + ): + self._maybe_delete_persisted_checkpoint(worst_checkpoint) + logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") + + self._replace_latest_persisted_checkpoint(checkpoint) else: # If the latest checkpoint has the same or lower priority, skip it. 
self._skip_persisted_checkpoint(checkpoint) @@ -239,18 +262,25 @@ def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): self._maybe_replace_best_persisted_checkpoint(persisted_checkpoint=checkpoint) self._cleanup_checkpoints() - def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + def _maybe_delete_persisted_checkpoint( + self, persisted_checkpoint: _TrackedCheckpoint + ): if persisted_checkpoint == self._latest_persisted_checkpoint: self._checkpoints_to_clean_up.add(persisted_checkpoint) else: - persisted_checkpoint.delete() + self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) + + def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + persisted_checkpoint.delete() + self._checkpoints_to_clean_up.discard(persisted_checkpoint) def _cleanup_checkpoints(self): - for checkpoint in self._checkpoints_to_clean_up: - self._delete_persisted_checkpoint(persisted_checkpoint=checkpoint) + for checkpoint in list(self._checkpoints_to_clean_up): + self._maybe_delete_persisted_checkpoint(persisted_checkpoint=checkpoint) def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") + self._checkpoints_to_clean_up.add(persisted_checkpoint) def _get_next_checkpoint_path(self) -> Optional[Path]: return None From e8da5b64a83bd9773ccc61a84bcaba8de1875988 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 3 May 2022 15:18:43 +0100 Subject: [PATCH 07/81] Fix trainer cp manager init --- python/ray/train/checkpoint.py | 16 +++++++++++++++- python/ray/tune/tests/test_checkpoint_manager.py | 8 +++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 1e183762a756..8a6fdbf9a520 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -3,10 +3,11 @@ from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID +from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID, TIMESTAMP from ray.train.constants import TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR from ray.train.session import TrainingResult from ray.train.utils import construct_path +from ray.tune.result import TRAINING_ITERATION as TUNE_TRAINING_ITERATION from ray.util.ml_utils.checkpoint_manager import ( CheckpointManager as CommonCheckpointManager, _TrackedCheckpoint, @@ -157,6 +158,19 @@ def _get_next_checkpoint_path(self) -> Optional[Path]: checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) return self.latest_checkpoint_dir.joinpath(checkpoint_file) + def on_start_training( + self, + checkpoint_strategy: CheckpointStrategy, + run_dir: str, + latest_checkpoint_id: int, + ): + if checkpoint_strategy.checkpoint_score_attribute == TUNE_TRAINING_ITERATION: + checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP + + self._checkpoint_strategy = checkpoint_strategy + self.run_dir = run_dir + self._latest_checkpoint_id = latest_checkpoint_id + # Train-specific attributes @property def latest_checkpoint_dir(self) -> Optional[Path]: diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 65392040abf3..bd82b6216acb 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -7,8 +7,8 
@@ from unittest.mock import patch from ray.tune.result import TRAINING_ITERATION -from ray.tune.checkpoint_manager import CheckpointManager, logger -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.tune.checkpoint_manager import CheckpointManager +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, logger class CheckpointManagerTest(unittest.TestCase): @@ -202,9 +202,7 @@ def testOnCheckpointUnavailableAttribute(self): result={}, ) - from ray.util.ml_utils.checkpoint_manager import logger as cp_logger - - with patch.object(cp_logger, "error") as log_error_mock: + with patch.object(logger, "error") as log_error_mock: checkpoint_manager.on_checkpoint(no_attr_checkpoint) log_error_mock.assert_called_once() # The newest checkpoint should still be set despite this error. From 4f9ace796a76ca67de065d3c333d95fc4fba042a Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 10:39:00 +0100 Subject: [PATCH 08/81] Default empty memory checkpoint --- python/ray/tune/checkpoint_manager.py | 6 +++++- python/ray/tune/syncer.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index cb57524225a5..e78163539dd7 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -102,7 +102,11 @@ def newest_checkpoint(self): @property def newest_memory_checkpoint(self): - return self._latest_memory_checkpoint + return self._latest_memory_checkpoint or _TrackedCheckpoint( + checkpoint_dir_or_data=None, + checkpoint_id=0, + storage_mode=_TrackedCheckpoint.MEMORY, + ) def best_checkpoints(self): """Returns best PERSISTENT checkpoints, sorted by score.""" diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index a24b63700d45..b76a6ec4c0b1 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -522,6 +522,9 @@ def _create_trial_syncer(self, trial: "Trial"): trial.logdir, remote_dir=trial.logdir, sync_function=self._sync_function ) + def _remove_trial_syncer(self, trial: "Trial"): + self._syncers.pop(trial, None) + def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint): if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: return @@ -604,8 +607,10 @@ def on_trial_complete( else: trainable_ip = ray.get(trial.runner.get_current_ip.remote()) trial_syncer.set_worker_ip(trainable_ip) - trial_syncer.sync_down_if_needed() + # Always sync down when trial completed + trial_syncer.sync_down() trial_syncer.close() + self._remove_trial_syncer(trial) def on_checkpoint( self, From 114b5b9e8a7e049297fec5abdbefb9717eefedf0 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 14:46:30 +0100 Subject: [PATCH 09/81] Default train checkpoint strategy --- python/ray/train/checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 8a6fdbf9a520..205b01dd0583 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -160,10 +160,12 @@ def _get_next_checkpoint_path(self) -> Optional[Path]: def on_start_training( self, - checkpoint_strategy: CheckpointStrategy, + checkpoint_strategy: Optional[CheckpointStrategy], run_dir: str, latest_checkpoint_id: int, ): + checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + if checkpoint_strategy.checkpoint_score_attribute == TUNE_TRAINING_ITERATION: checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP From 
c4eb5f2eb359b7c24b59a55d4e486954853599f8 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 14:51:10 +0100 Subject: [PATCH 10/81] checkpoint.value --> checkpoint.checkpoint_dir_or_data --- doc/source/tune/api_docs/trainable.rst | 2 +- .../tune/examples/tune-pytorch-cifar.ipynb | 2 +- .../tune-serve-integration-mnist.ipynb | 4 ++-- python/ray/train/tests/test_tune.py | 8 +++++--- python/ray/tune/examples/cifar10_pytorch.py | 2 +- python/ray/tune/result_grid.py | 6 ++++-- python/ray/tune/tests/ext_pytorch.py | 2 +- python/ray/tune/tests/test_cluster.py | 2 +- .../tune/tests/test_experiment_analysis.py | 4 +++- python/ray/tune/tests/test_function_api.py | 20 +++++++++++++------ python/ray/tune/tests/test_trial_runner_2.py | 8 ++++---- python/ray/tune/tests/test_trial_scheduler.py | 2 +- .../bandit/tune_lin_ts_train_wheel_env.py | 2 +- 13 files changed, 39 insertions(+), 25 deletions(-) diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 5ac3eedd3e70..adb3032ef7a7 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -142,7 +142,7 @@ You can restore a single trial checkpoint by using ``tune.run(restore= Optional[Union[TuneError, RayTaskError] return None def _trial_to_result(self, trial: Trial) -> Result: - if trial.checkpoint.value: - checkpoint_dir = TrainableUtil.find_checkpoint_dir(trial.checkpoint.value) + if trial.checkpoint.checkpoint_dir_or_data: + checkpoint_dir = TrainableUtil.find_checkpoint_dir( + trial.checkpoint.checkpoint_dir_or_data + ) checkpoint = Checkpoint.from_directory(checkpoint_dir) else: checkpoint = None diff --git a/python/ray/tune/tests/ext_pytorch.py b/python/ray/tune/tests/ext_pytorch.py index b8864f9dcd36..a602157f1051 100644 --- a/python/ray/tune/tests/ext_pytorch.py +++ b/python/ray/tune/tests/ext_pytorch.py @@ -424,7 +424,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2): best_trained_model = nn.DataParallel(best_trained_model) best_trained_model.to(device) - best_checkpoint_dir = best_trial.checkpoint.value + best_checkpoint_dir = best_trial.checkpoint.checkpoint_dir_or_data model_state, optimizer_state = torch.load(os.path.join( best_checkpoint_dir, "checkpoint")) best_trained_model.load_state_dict(model_state) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index e1914e18fba5..5f5d387f0a11 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -394,7 +394,7 @@ def hidden_path_func(checkpoint_path): cluster.add_node(num_cpus=1) cluster.remove_node(node) cluster.wait_for_nodes() - shutil.rmtree(os.path.dirname(t1.checkpoint.value)) + shutil.rmtree(os.path.dirname(t1.checkpoint.checkpoint_dir_or_data)) while not runner.is_finished(): runner.step() assert t1.status == Trial.TERMINATED, runner.debug_string() diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index eabbbd1576c7..84fcde3006a6 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -309,7 +309,9 @@ def train(config): self.assertEqual(ea.best_trial, trials[2]) self.assertEqual(ea.best_config, trials[2].config) self.assertEqual(ea.best_logdir, trials[2].logdir) - self.assertEqual(ea.best_checkpoint._local_path, trials[2].checkpoint.value) + self.assertEqual( + ea.best_checkpoint._local_path, trials[2].checkpoint.checkpoint_dir_or_data + ) 
self.assertTrue(all(ea.best_dataframe["trial_id"] == trials[2].trial_id)) self.assertEqual(ea.results_df.loc[trials[2].trial_id, "res"], 309) self.assertEqual(ea.best_result["res"], 309) diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index a96d1760ebe7..230ba0f6be00 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -321,7 +321,9 @@ def train(config, checkpoint_dir=False): f.write("hello") [trial] = tune.run(train).trials - assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log")) + assert os.path.exists( + os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log") + ) def testCheckpointFunctionAtEndContext(self): def train(config, checkpoint_dir=False): @@ -333,7 +335,9 @@ def train(config, checkpoint_dir=False): f.write("hello") [trial] = tune.run(train).trials - assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log")) + assert os.path.exists( + os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log") + ) def testVariousCheckpointFunctionAtEnd(self): def train(config, checkpoint_dir=False): @@ -349,7 +353,9 @@ def train(config, checkpoint_dir=False): f.write("goodbye") [trial] = tune.run(train, keep_checkpoints_num=3).trials - assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log2")) + assert os.path.exists( + os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log2") + ) def testReuseCheckpoint(self): def train(config, checkpoint_dir=None): @@ -369,8 +375,10 @@ def train(config, checkpoint_dir=None): train, config={"max_iter": 5}, ).trials - last_ckpt = trial.checkpoint.value - assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log")) + last_ckpt = trial.checkpoint.checkpoint_dir_or_data + assert os.path.exists( + os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log") + ) analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 5 @@ -393,7 +401,7 @@ def train(config, checkpoint_dir=None): tune.report(test=i, training_iteration=i) analysis = tune.run(train, max_failures=3) - last_ckpt = analysis.trials[0].checkpoint.value + last_ckpt = analysis.trials[0].checkpoint.checkpoint_dir_or_data assert os.path.exists(os.path.join(last_ckpt, "ckpt.log")) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 10 diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 8832274976bc..545b45b59d30 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -211,7 +211,7 @@ def testCheckpointing(self): self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) runner.step() # Process result, dispatch save runner.step() # Process save, stop trial - kwargs["restore_path"] = trials[0].checkpoint.value + kwargs["restore_path"] = trials[0].checkpoint.checkpoint_dir_or_data self.assertEqual(trials[0].status, Trial.TERMINATED) runner.add_trial(Trial("__fake", **kwargs)) @@ -224,7 +224,7 @@ def testCheckpointing(self): self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) - self.addCleanup(os.remove, trials[0].checkpoint.value) + self.addCleanup(os.remove, trials[0].checkpoint.checkpoint_dir_or_data) def 
testRestoreMetricsAfterCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) @@ -244,7 +244,7 @@ def testRestoreMetricsAfterCheckpointing(self): self.assertEqual(trials[0].status, Trial.TERMINATED) - kwargs["restore_path"] = trials[0].checkpoint.value + kwargs["restore_path"] = trials[0].checkpoint.checkpoint_dir_or_data kwargs.pop("stopping_criterion") kwargs.pop("checkpoint_freq") # No checkpointing for next trial runner.add_trial(Trial("__fake", **kwargs)) @@ -263,7 +263,7 @@ def testRestoreMetricsAfterCheckpointing(self): self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20) self.assertEqual(trials[1].last_result["iterations_since_restore"], 2) self.assertGreater(trials[1].last_result["time_since_restore"], 0) - self.addCleanup(os.remove, trials[0].checkpoint.value) + self.addCleanup(os.remove, trials[0].checkpoint.checkpoint_dir_or_data) def testCheckpointingAtEnd(self): ray.init(num_cpus=1, num_gpus=1) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index ddb0c9781929..a7e2ffb6a6f4 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -843,7 +843,7 @@ def __init__(self, i, config): self._default_result_or_future = None def on_checkpoint(self, checkpoint): - self.restored_checkpoint = checkpoint.value + self.restored_checkpoint = checkpoint.checkpoint_dir_or_data @property def checkpoint(self): diff --git a/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py b/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py index f4de70d5b77d..8e8a0ab88552 100644 --- a/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py +++ b/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py @@ -86,7 +86,7 @@ def plot_model_weights(means, covs, ax): # Restore trainer from checkpoint trial = analysis.trials[0] trainer = BanditLinTSTrainer(config=config) - trainer.restore(trial.checkpoint.value) + trainer.restore(trial.checkpoint.checkpoint_dir_or_data) # Get model to plot arm weights distribution model = trainer.get_policy().model From 2c671658dafd7a8341c63a962dab9a538f0ca9f0 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 16:19:13 +0100 Subject: [PATCH 11/81] init latest checkpoint id --- python/ray/train/checkpoint.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 205b01dd0583..bd2d665e10e4 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -60,7 +60,12 @@ def delete(self): @classmethod def from_tracked_checkpoint(cls, checkpoint: _TrackedCheckpoint): new_checkpoint = cls( - **{**checkpoint.__dict__, "storage_mode": _TrackedCheckpoint.PERSISTENT} + checkpoint_dir_or_data=checkpoint.checkpoint_dir_or_data, + storage_mode=_TrackedCheckpoint.PERSISTENT, + checkpoint_id=checkpoint.checkpoint_id, + result=checkpoint.result, + node_ip=checkpoint.node_ip, + delete_fn=checkpoint._delete_fn, ) return new_checkpoint @@ -162,7 +167,7 @@ def on_start_training( self, checkpoint_strategy: Optional[CheckpointStrategy], run_dir: str, - latest_checkpoint_id: int, + latest_checkpoint_id: Optional[int] = 0, ): checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() @@ -171,7 +176,7 @@ def on_start_training( self._checkpoint_strategy = checkpoint_strategy self.run_dir = run_dir - self._latest_checkpoint_id = latest_checkpoint_id + self._latest_checkpoint_id = latest_checkpoint_id or 0 # Train-specific attributes @property From 
8a37de411e7e7ed7406ca33a659d679d674a6dea Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 17:09:49 +0100 Subject: [PATCH 12/81] train latest_checkpoint property --- python/ray/train/checkpoint.py | 4 ++++ python/ray/util/ml_utils/checkpoint_manager.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index bd2d665e10e4..008b54d51823 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -179,6 +179,10 @@ def on_start_training( self._latest_checkpoint_id = latest_checkpoint_id or 0 # Train-specific attributes + @property + def latest_checkpoint(self): + return self._latest_memory_checkpoint.checkpoint_dir_or_data + @property def latest_checkpoint_dir(self) -> Optional[Path]: """Path to the latest checkpoint directory.""" diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index c054b0e67667..22451f878dd0 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -138,8 +138,6 @@ def __init__( ): self._checkpoint_strategy = checkpoint_strategy - self.latest_checkpoint = None - # Incremental unique checkpoint ID of this run. self._latest_checkpoint_id = latest_checkpoint_id From 5b218cd1d5973ddf51292215f0c5e30efae363e4 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 20:31:50 +0100 Subject: [PATCH 13/81] Keep top checkpoints when keep_num is None --- python/ray/tune/analysis/experiment_analysis.py | 3 ++- python/ray/util/ml_utils/checkpoint_manager.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 5d1bd3621f1e..28c7e330ae83 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -423,7 +423,8 @@ def get_trial_checkpoints_paths( # Support metrics given as paths, e.g. # "info/learner/default_policy/policy_loss". 
return [ - (c.value, unflattened_lookup(metric, c.result)) for c in checkpoints + (c.checkpoint_dir_or_data, unflattened_lookup(metric, c.result)) + for c in checkpoints ] else: raise ValueError("trial should be a string or a Trial instance.") diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 22451f878dd0..bf15a33b88d1 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -230,6 +230,7 @@ def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): # Keep all checkpoints checkpoint.commit(path=self._get_next_checkpoint_path()) self._replace_latest_persisted_checkpoint(checkpoint) + self._top_persisted_checkpoints.append(wrapped_checkpoint) elif ( len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep ): From ef489c2a145b2ab0d0237c226a50c4fc643da3a5 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 20:52:29 +0100 Subject: [PATCH 14/81] Fix train checkpoint bookkeeping and serialization --- python/ray/train/checkpoint.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 008b54d51823..03b9ced0ed4a 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -35,10 +35,15 @@ class _NotYetPersistedCheckpoint(_TrackedCheckpoint): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._committed = False + self._data_to_commit = self.checkpoint_dir_or_data + self.checkpoint_dir_or_data = None + + @property + def committed(self): + return not self._data_to_commit def commit(self, path: Optional[Path] = None): - if self._committed: + if self.committed: return assert path @@ -47,10 +52,10 @@ def commit(self, path: Optional[Path] = None): path.parent.mkdir(parents=True, exist_ok=True) # Write checkpoint to disk. with path.open("wb") as f: - cloudpickle.dump(self, f) + cloudpickle.dump(self._data_to_commit, f) logger.debug(f"Checkpoint successfully written to: {path}") - self._committed = True + self.checkpoint_dir_or_data = path def delete(self): if not self._committed: @@ -181,6 +186,8 @@ def on_start_training( # Train-specific attributes @property def latest_checkpoint(self): + if not self._latest_memory_checkpoint: + return None return self._latest_memory_checkpoint.checkpoint_dir_or_data @property @@ -252,7 +259,7 @@ def _decide_what_to_do_with_checkpoint( # the checkpoint from. 
file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) with file_path.open("wb") as f: - cloudpickle.dump(checkpoint, f) + cloudpickle.dump(checkpoint.checkpoint_dir_or_data, f) checkpoint._committed = True From d2d7aaeb5183b1a2d763c51d0411f2eb1f4465ed Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 20:59:18 +0100 Subject: [PATCH 15/81] Default checkpoint id --- python/ray/tune/checkpoint_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index e78163539dd7..9c6d8cbd4b6a 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -87,7 +87,7 @@ def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): def newest_persistent_checkpoint(self): return self._latest_persisted_checkpoint or _TrackedCheckpoint( checkpoint_dir_or_data=None, - checkpoint_id=0, + checkpoint_id=-1, storage_mode=_TrackedCheckpoint.PERSISTENT, ) @@ -104,7 +104,7 @@ def newest_checkpoint(self): def newest_memory_checkpoint(self): return self._latest_memory_checkpoint or _TrackedCheckpoint( checkpoint_dir_or_data=None, - checkpoint_id=0, + checkpoint_id=-1, storage_mode=_TrackedCheckpoint.MEMORY, ) From ea9eb504a564e5f9ef0ac9ce6b6301163bea1a9f Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 22:20:50 +0100 Subject: [PATCH 16/81] Fix TuneCheckpointManager --- python/ray/train/checkpoint.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 03b9ced0ed4a..8d256cb41a6f 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -251,16 +251,19 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): def _decide_what_to_do_with_checkpoint( self, checkpoint: _NotYetPersistedCheckpoint ): - self.add_tune_checkpoint_id(checkpoint.checkpoint_dir_or_data) + assert isinstance(checkpoint, _NotYetPersistedCheckpoint) + assert not checkpoint.committed + + self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) # Use a standard file name so that we know which file to load # the checkpoint from. 
file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - with file_path.open("wb") as f: - cloudpickle.dump(checkpoint.checkpoint_dir_or_data, f) - checkpoint._committed = True + checkpoint.commit(file_path) + + return super()._decide_what_to_do_with_checkpoint(checkpoint) def construct_checkpoint_file_name(checkpoint_id: int) -> str: From 74f29285d15abee9c99addb69efba687b52da04e Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 22:54:15 +0100 Subject: [PATCH 17/81] Fix latest checkpoint id increment --- python/ray/train/checkpoint.py | 2 ++ python/ray/tune/tests/test_trial_runner_2.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 8d256cb41a6f..c699e9353b50 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -163,6 +163,8 @@ def _process_checkpoint( ) self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) + self._latest_checkpoint_id += 1 + def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 545b45b59d30..ecbaa8137cb2 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -339,7 +339,6 @@ def write_checkpoint(trial: Trial, index: int): result=result, ) trial.saving_to = tune_cp - trial.on_checkpoint(tune_cp) return checkpoint_dir From bb4ede2a59598ef909556f12dae16657e3ad8c62 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 May 2022 22:57:26 +0100 Subject: [PATCH 18/81] Update delete fn --- python/ray/tune/checkpoint_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 9c6d8cbd4b6a..ebdba2bc3c86 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -56,11 +56,15 @@ def __init__( super().__init__(checkpoint_strategy=checkpoint_strategy) def on_checkpoint(self, checkpoint: _TrackedCheckpoint): + # Set checkpoint ID checkpoint.checkpoint_id = ( checkpoint.checkpoint_id or self._latest_checkpoint_id ) self._latest_checkpoint_id += 1 + # Set delete fn + checkpoint._delete_fn = checkpoint._delete_fn or self._delete_fn + if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: self._replace_latest_memory_checkpoint(checkpoint) else: From a92044e5970174ffbe513d5289352bdcfac6c31f Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 5 May 2022 10:21:35 +0100 Subject: [PATCH 19/81] Delete fn should be property of checkpoint manager, not TrackedCheckpoint --- python/ray/train/checkpoint.py | 3 +- python/ray/tune/checkpoint_manager.py | 21 ++------------ python/ray/tune/tests/test_trial_runner_2.py | 8 ++++-- python/ray/tune/trial.py | 4 +-- .../ray/util/ml_utils/checkpoint_manager.py | 28 ++++++++++++++----- 5 files changed, 33 insertions(+), 31 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index c699e9353b50..1e43dc393087 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -58,7 +58,7 @@ def commit(self, path: Optional[Path] = None): self.checkpoint_dir_or_data = path def delete(self): - if not self._committed: + if not self.committed: return return super().delete() @@ -70,7 +70,6 @@ def from_tracked_checkpoint(cls, checkpoint: 
_TrackedCheckpoint): checkpoint_id=checkpoint.checkpoint_id, result=checkpoint.result, node_ip=checkpoint.node_ip, - delete_fn=checkpoint._delete_fn, ) return new_checkpoint diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index ebdba2bc3c86..9a7b76eaef58 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -1,6 +1,6 @@ # coding: utf-8 import logging -from typing import Callable +from typing import Callable, Optional from ray.util.ml_utils.checkpoint_manager import ( CheckpointStrategy, @@ -31,7 +31,7 @@ def __init__( self, keep_checkpoints_num: int, checkpoint_score_attr: str, - delete_fn: Callable[[str], None], + delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, ): if keep_checkpoints_num == 0: raise RuntimeError( @@ -51,9 +51,7 @@ def __init__( checkpoint_score_order=MIN if checkpoint_score_desc else MAX, ) - self._delete_fn = delete_fn - - super().__init__(checkpoint_strategy=checkpoint_strategy) + super().__init__(checkpoint_strategy=checkpoint_strategy, delete_fn=delete_fn) def on_checkpoint(self, checkpoint: _TrackedCheckpoint): # Set checkpoint ID @@ -62,9 +60,6 @@ def on_checkpoint(self, checkpoint: _TrackedCheckpoint): ) self._latest_checkpoint_id += 1 - # Set delete fn - checkpoint._delete_fn = checkpoint._delete_fn or self._delete_fn - if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: self._replace_latest_memory_checkpoint(checkpoint) else: @@ -116,13 +111,3 @@ def best_checkpoints(self): """Returns best PERSISTENT checkpoints, sorted by score.""" checkpoints = sorted(self._top_persisted_checkpoints, key=lambda c: c.priority) return [wrapped.tracked_checkpoint for wrapped in checkpoints] - - def __getstate__(self): - state = super().__getstate__() - # Avoid serializing delete fn as it may contain cyclical dependencies - state.pop("_delete_fn", None) - return state - - def __setstate__(self, state): - state["_delete_fn"] = None - super().__setstate__(state) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index ecbaa8137cb2..368fc96981df 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -323,7 +323,9 @@ def testPauseResumeCheckpointCount(self): trial = Trial("__fake", keep_checkpoints_num=2) trial.init_logdir() - trial.checkpoint_manager._delete_fn = lambda cp: shutil.rmtree(cp.value) + trial.checkpoint_manager.set_delete_fn( + lambda cp: shutil.rmtree(cp.checkpoint_dir_or_data) + ) def write_checkpoint(trial: Trial, index: int): checkpoint_dir = TrainableUtil.make_checkpoint_dir( @@ -378,7 +380,9 @@ def get_checkpoint_dirs(trial: Trial): runner.resume() trial = runner.get_trials()[0] - trial.checkpoint_manager._delete_fn = lambda cp: shutil.rmtree(cp.value) + trial.checkpoint_manager.set_delete_fn( + lambda cp: shutil.rmtree(cp.checkpoint_dir_or_data) + ) # Write fourth checkpoint result = write_checkpoint(trial, 4) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 33ab18665118..dab0e7c21ccb 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -571,8 +571,8 @@ def set_runner(self, runner): self._default_result_or_future = runner.get_auto_filled_metrics.remote( debug_metrics_only=True ) - self.checkpoint_manager._delete_fn = CheckpointDeleter( - self._trainable_name(), runner + self.checkpoint_manager.set_delete_fn( + CheckpointDeleter(self._trainable_name(), runner) ) # No need to invalidate state cache: runner 
is not stored in json # self.invalidate_json_state() diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index bf15a33b88d1..7efb63787030 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -30,7 +30,6 @@ def __init__( checkpoint_id: Optional[int] = None, result: Optional[Dict] = None, node_ip: Optional[str] = None, - delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, ): self.checkpoint_dir_or_data = checkpoint_dir_or_data self.checkpoint_id = checkpoint_id @@ -40,13 +39,14 @@ def __init__( self.result = result or {} self.node_ip = node_ip or self.result.get(NODE_IP, None) - self._delete_fn = delete_fn or _default_delete_fn - def commit(self, path: Optional[Path] = None): pass - def delete(self): - self._delete_fn(self) + def delete( + self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None + ): + delete_fn = delete_fn or _default_delete_fn + delete_fn(self) def __repr__(self): if self.storage_mode == _TrackedCheckpoint.MEMORY: @@ -134,7 +134,10 @@ def __post_init__(self): class CheckpointManager: def __init__( - self, checkpoint_strategy: CheckpointStrategy, latest_checkpoint_id: int = 0 + self, + checkpoint_strategy: CheckpointStrategy, + latest_checkpoint_id: int = 0, + delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, ): self._checkpoint_strategy = checkpoint_strategy @@ -152,6 +155,12 @@ def __init__( # Checkpoints that are not immediately removed self._checkpoints_to_clean_up = set() + self._delete_fn = delete_fn + + def set_delete_fn( + self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] + ): + self._delete_fn = delete_fn def _replace_latest_memory_checkpoint(self, memory_checkpoint: _TrackedCheckpoint): assert memory_checkpoint.storage_mode == _TrackedCheckpoint.MEMORY @@ -270,7 +279,7 @@ def _maybe_delete_persisted_checkpoint( self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): - persisted_checkpoint.delete() + persisted_checkpoint.delete(delete_fn=self._delete_fn) self._checkpoints_to_clean_up.discard(persisted_checkpoint) def _cleanup_checkpoints(self): @@ -286,6 +295,10 @@ def _get_next_checkpoint_path(self) -> Optional[Path]: def __getstate__(self): state = self.__dict__.copy() + + # Do not serialize the delete fn + state.pop("_delete_fn", None) + # Avoid serializing the memory checkpoint. 
state["_newest_memory_checkpoint"] = _TrackedCheckpoint( checkpoint_dir_or_data=None, @@ -295,4 +308,5 @@ def __getstate__(self): return state def __setstate__(self, state): + state["_delete_fn"] = None self.__dict__.update(state) From 5edffe5b7c267ba3f84702fed33c4464ada5911e Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 5 May 2022 12:34:29 +0100 Subject: [PATCH 20/81] fix train checkpoint deletion --- python/ray/train/checkpoint.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 1e43dc393087..362d778e04f0 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -57,10 +57,12 @@ def commit(self, path: Optional[Path] = None): self.checkpoint_dir_or_data = path - def delete(self): + def delete( + self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None + ): if not self.committed: return - return super().delete() + return super().delete(delete_fn=delete_fn) @classmethod def from_tracked_checkpoint(cls, checkpoint: _TrackedCheckpoint): From 01a680806529df19db328b63ec1e96bd03d2815d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 5 May 2022 15:29:41 +0100 Subject: [PATCH 21/81] Clear data on commit --- python/ray/train/checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 362d778e04f0..9bd3ce5afadf 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -56,6 +56,7 @@ def commit(self, path: Optional[Path] = None): logger.debug(f"Checkpoint successfully written to: {path}") self.checkpoint_dir_or_data = path + self._data_to_commit = None def delete( self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None From 98aff7dc7c14fd68029d7953714fed0baefda2fd Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 5 May 2022 16:26:42 +0100 Subject: [PATCH 22/81] Pre-review --- python/ray/train/checkpoint.py | 14 +++++++++++--- python/ray/util/ml_utils/checkpoint_manager.py | 5 ++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 9bd3ce5afadf..02f93a1e98cf 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -3,8 +3,13 @@ from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID, TIMESTAMP -from ray.train.constants import TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR +from ray.train.constants import ( + TRAIN_CHECKPOINT_SUBDIR, + TRAINING_ITERATION, + TUNE_CHECKPOINT_FILE_NAME, + TUNE_CHECKPOINT_ID, + TUNE_INSTALLED, +) from ray.train.session import TrainingResult from ray.train.utils import construct_path from ray.tune.result import TRAINING_ITERATION as TUNE_TRAINING_ITERATION @@ -180,8 +185,11 @@ def on_start_training( ): checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + # We only want to support one CheckpointStrategy object. Thus, + # for Ray Train we update the default score attribute for Ray Train's + # version. 
if checkpoint_strategy.checkpoint_score_attribute == TUNE_TRAINING_ITERATION: - checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP + checkpoint_strategy.checkpoint_score_attribute = TRAINING_ITERATION self._checkpoint_strategy = checkpoint_strategy self.run_dir = run_dir diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 7efb63787030..7e3dd8bc1f8e 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -4,6 +4,7 @@ import numbers import os import shutil + from dataclasses import dataclass from pathlib import Path from typing import Optional, Dict, Union, Callable, Tuple, List, Any @@ -40,11 +41,13 @@ def __init__( self.node_ip = node_ip or self.result.get(NODE_IP, None) def commit(self, path: Optional[Path] = None): + """Commit checkpoint to disk, if needed.""" pass def delete( self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None ): + """Delete checkpoint from disk, if needed.""" delete_fn = delete_fn or _default_delete_fn delete_fn(self) @@ -90,7 +93,7 @@ def __repr__(self): @PublicAPI(stability="beta") @dataclass class CheckpointStrategy: - """Configurable parameters for defining the Train checkpointing strategy. + """Configurable parameters for defining the checkpointing strategy. Default behavior is to persist all checkpoints to disk. If ``num_to_keep`` is set, the default retention policy is to keep the From cba0450604d962a912e28624527e44e1d00b5422 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 5 May 2022 18:03:05 +0100 Subject: [PATCH 23/81] training iteration -> timestamp --- python/ray/train/checkpoint.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 02f93a1e98cf..8befafcbdfb8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -4,8 +4,8 @@ from ray import cloudpickle from ray.train.constants import ( + TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, - TRAINING_ITERATION, TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID, TUNE_INSTALLED, @@ -186,10 +186,9 @@ def on_start_training( checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() # We only want to support one CheckpointStrategy object. Thus, - # for Ray Train we update the default score attribute for Ray Train's - # version. 
+ # for Ray Train we update the default score attribute for Ray Train if checkpoint_strategy.checkpoint_score_attribute == TUNE_TRAINING_ITERATION: - checkpoint_strategy.checkpoint_score_attribute = TRAINING_ITERATION + checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP self._checkpoint_strategy = checkpoint_strategy self.run_dir = run_dir From 0363b4dd0efcd3d2503b47806ef45f06f882d3b6 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 6 May 2022 12:15:49 +0100 Subject: [PATCH 24/81] Update python/ray/util/ml_utils/checkpoint_manager.py Co-authored-by: Antoni Baum --- python/ray/util/ml_utils/checkpoint_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 7e3dd8bc1f8e..85943de3bacc 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -125,7 +125,7 @@ class CheckpointStrategy: def __post_init__(self): if self.num_to_keep is not None and self.num_to_keep < 0: raise ValueError( - f"Received invalidate num_to_keep: " + f"Received invalid num_to_keep: " f"{self.num_to_keep}. " f"Must be None or non-negative integer." ) From 388dddef15d24884b9262804fa9954898c13f212 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 6 May 2022 12:26:23 +0100 Subject: [PATCH 25/81] Rename dataclass attributes --- doc/source/tune/api_docs/trainable.rst | 2 +- .../tune/examples/tune-pytorch-cifar.ipynb | 2 +- .../tune-serve-integration-mnist.ipynb | 4 +- python/ray/train/checkpoint.py | 23 ++++++---- python/ray/train/tests/test_tune.py | 6 +-- .../ray/tune/analysis/experiment_analysis.py | 2 +- python/ray/tune/checkpoint_manager.py | 10 ++-- python/ray/tune/examples/cifar10_pytorch.py | 2 +- python/ray/tune/ray_trial_executor.py | 8 ++-- python/ray/tune/result_grid.py | 4 +- python/ray/tune/syncer.py | 4 +- python/ray/tune/tests/ext_pytorch.py | 2 +- .../ray/tune/tests/test_checkpoint_manager.py | 44 +++++++++--------- python/ray/tune/tests/test_cluster.py | 2 +- .../tune/tests/test_experiment_analysis.py | 2 +- python/ray/tune/tests/test_function_api.py | 20 +++----- .../ray/tune/tests/test_ray_trial_executor.py | 4 +- python/ray/tune/tests/test_trial_runner_2.py | 18 +++----- .../tune/tests/test_trial_runner_callbacks.py | 2 +- python/ray/tune/tests/test_trial_scheduler.py | 6 +-- .../tune/tests/test_trial_scheduler_pbt.py | 2 +- .../test_trial_scheduler_resource_changing.py | 2 +- python/ray/tune/trial.py | 12 ++--- python/ray/tune/trial_runner.py | 2 +- .../ray/util/ml_utils/checkpoint_manager.py | 46 ++++++++++++------- .../bandit/tune_lin_ts_train_wheel_env.py | 2 +- 26 files changed, 119 insertions(+), 114 deletions(-) diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index adb3032ef7a7..4d05c9353918 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -142,7 +142,7 @@ You can restore a single trial checkpoint by using ``tune.run(restore= Dict: class _NotYetPersistedCheckpoint(_TrackedCheckpoint): + """Tracked checkpoint that is not yet persisted to disk. + + This checkpoint class supports lazy writing. The checkpoint manager will + only call ``commit()`` if the checkpoint should be kept on disk. This class + will only then write checkpoint data to disk. 
+ """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._data_to_commit = self.checkpoint_dir_or_data - self.checkpoint_dir_or_data = None + self._data_to_commit = self.dir_or_data + self.dir_or_data = None @property def committed(self): @@ -60,7 +67,7 @@ def commit(self, path: Optional[Path] = None): cloudpickle.dump(self._data_to_commit, f) logger.debug(f"Checkpoint successfully written to: {path}") - self.checkpoint_dir_or_data = path + self.dir_or_data = path self._data_to_commit = None def delete( @@ -73,9 +80,9 @@ def delete( @classmethod def from_tracked_checkpoint(cls, checkpoint: _TrackedCheckpoint): new_checkpoint = cls( - checkpoint_dir_or_data=checkpoint.checkpoint_dir_or_data, + dir_or_data=checkpoint.dir_or_data, storage_mode=_TrackedCheckpoint.PERSISTENT, - checkpoint_id=checkpoint.checkpoint_id, + checkpoint_id=checkpoint.id, result=checkpoint.result, node_ip=checkpoint.node_ip, ) @@ -154,7 +161,7 @@ def _process_checkpoint( ) tracked_checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data=checkpoint_data, + dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=_TrackedCheckpoint.MEMORY, result={score_attr: checkpoint_data.get(score_attr, 0.0)}, @@ -199,7 +206,7 @@ def on_start_training( def latest_checkpoint(self): if not self._latest_memory_checkpoint: return None - return self._latest_memory_checkpoint.checkpoint_dir_or_data + return self._latest_memory_checkpoint.dir_or_data @property def latest_checkpoint_dir(self) -> Optional[Path]: @@ -225,7 +232,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return Path(self._best_persisted_checkpoint.checkpoint_dir_or_data) + return Path(self._best_persisted_checkpoint.dir_or_data) else: return None diff --git a/python/ray/train/tests/test_tune.py b/python/ray/train/tests/test_tune.py index 619a9808ed39..49576fa52243 100644 --- a/python/ray/train/tests/test_tune.py +++ b/python/ray/train/tests/test_tune.py @@ -118,7 +118,7 @@ def train_func(): [trial] = tune.run(TestTrainable).trials checkpoint_file = os.path.join( - trial.checkpoint.checkpoint_dir_or_data, TUNE_CHECKPOINT_FILE_NAME + trial.checkpoint.dir_or_data, TUNE_CHECKPOINT_FILE_NAME ) assert os.path.exists(checkpoint_file) with open(checkpoint_file, "rb") as f: @@ -141,7 +141,7 @@ def train_func(config): TestTrainable = trainer.to_tune_trainable(train_func) [trial] = tune.run(TestTrainable, config={"max_iter": 5}).trials - last_ckpt = trial.checkpoint.checkpoint_dir_or_data + last_ckpt = trial.checkpoint.dir_or_data checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME) assert os.path.exists(checkpoint_file) with open(checkpoint_file, "rb") as f: @@ -170,7 +170,7 @@ def train_func(): TestTrainable = trainer.to_tune_trainable(train_func) analysis = tune.run(TestTrainable, max_failures=3) - last_ckpt = analysis.trials[0].checkpoint.checkpoint_dir_or_data + last_ckpt = analysis.trials[0].checkpoint.dir_or_data checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME) assert os.path.exists(checkpoint_file) with open(checkpoint_file, "rb") as f: diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 28c7e330ae83..546a18357154 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -423,7 +423,7 @@ def get_trial_checkpoints_paths( # 
Support metrics given as paths, e.g. # "info/learner/default_policy/policy_loss". return [ - (c.checkpoint_dir_or_data, unflattened_lookup(metric, c.result)) + (c.dir_or_data, unflattened_lookup(metric, c.result)) for c in checkpoints ] else: diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 9a7b76eaef58..d5cde13a0ac2 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -55,9 +55,7 @@ def __init__( def on_checkpoint(self, checkpoint: _TrackedCheckpoint): # Set checkpoint ID - checkpoint.checkpoint_id = ( - checkpoint.checkpoint_id or self._latest_checkpoint_id - ) + checkpoint.id = checkpoint.id or self._latest_checkpoint_id self._latest_checkpoint_id += 1 if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: @@ -85,7 +83,7 @@ def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): @property def newest_persistent_checkpoint(self): return self._latest_persisted_checkpoint or _TrackedCheckpoint( - checkpoint_dir_or_data=None, + dir_or_data=None, checkpoint_id=-1, storage_mode=_TrackedCheckpoint.PERSISTENT, ) @@ -95,14 +93,14 @@ def newest_checkpoint(self): """Returns the newest checkpoint (based on training iteration).""" newest_checkpoint = max( [self.newest_persistent_checkpoint, self.newest_memory_checkpoint], - key=lambda c: c.checkpoint_id, + key=lambda c: c.id, ) return newest_checkpoint @property def newest_memory_checkpoint(self): return self._latest_memory_checkpoint or _TrackedCheckpoint( - checkpoint_dir_or_data=None, + dir_or_data=None, checkpoint_id=-1, storage_mode=_TrackedCheckpoint.MEMORY, ) diff --git a/python/ray/tune/examples/cifar10_pytorch.py b/python/ray/tune/examples/cifar10_pytorch.py index 6d2c9b948f77..3e56f8bef729 100644 --- a/python/ray/tune/examples/cifar10_pytorch.py +++ b/python/ray/tune/examples/cifar10_pytorch.py @@ -165,7 +165,7 @@ def test_best_model(best_trial): device = "cuda:0" if torch.cuda.is_available() else "cpu" best_trained_model.to(device) - checkpoint_path = os.path.join(best_trial.checkpoint.checkpoint_dir_or_data, "checkpoint") + checkpoint_path = os.path.join(best_trial.checkpoint.dir_or_data, "checkpoint") model_state, optimizer_state = torch.load(checkpoint_path) best_trained_model.load_state_dict(model_state) diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index e66b778a7691..79ea74e13cc7 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -693,13 +693,13 @@ def save( if storage == _TrackedCheckpoint.MEMORY: value = trial.runner.save_to_object.remote() checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data=value, storage_mode=storage, result=result + dir_or_data=value, storage_mode=storage, result=result ) trial.on_checkpoint(checkpoint) else: value = trial.runner.save.remote() checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data=value, storage_mode=storage, result=result + dir_or_data=value, storage_mode=storage, result=result ) trial.saving_to = checkpoint self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial) @@ -717,13 +717,13 @@ def restore(self, trial: Trial) -> None: ineligible for restoration, given the Tune input arguments. 
""" checkpoint = trial.checkpoint - if checkpoint.checkpoint_dir_or_data is None: + if checkpoint.dir_or_data is None: return if trial.runner is None: raise RuntimeError( "Trial {}: Unable to restore - no runner found.".format(trial) ) - value = checkpoint.checkpoint_dir_or_data + value = checkpoint.dir_or_data node_ip = checkpoint.node_ip if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: logger.debug("Trial %s: Attempting restore from object", trial) diff --git a/python/ray/tune/result_grid.py b/python/ray/tune/result_grid.py index 0c3629495376..3df86fa86ac4 100644 --- a/python/ray/tune/result_grid.py +++ b/python/ray/tune/result_grid.py @@ -165,9 +165,9 @@ def _populate_exception(trial: Trial) -> Optional[Union[TuneError, RayTaskError] return None def _trial_to_result(self, trial: Trial) -> Result: - if trial.checkpoint.checkpoint_dir_or_data: + if trial.checkpoint.dir_or_data: checkpoint_dir = TrainableUtil.find_checkpoint_dir( - trial.checkpoint.checkpoint_dir_or_data + trial.checkpoint.dir_or_data ) checkpoint = Checkpoint.from_directory(checkpoint_dir) else: diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index b76a6ec4c0b1..4a117d1af72f 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -565,7 +565,7 @@ def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint) # shouldn't track it with the checkpoint_manager. raise e if not trial.uses_cloud_checkpointing: - if not os.path.exists(checkpoint.checkpoint_dir_or_data): + if not os.path.exists(checkpoint.dir_or_data): raise TuneError( "Trial {}: Checkpoint path {} not " "found after successful sync down. " @@ -576,7 +576,7 @@ def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint) "if that's the case, see instructions " "here: {} .".format( trial, - checkpoint.checkpoint_dir_or_data, + checkpoint.dir_or_data, CLOUD_CHECKPOINTING_URL, ) ) diff --git a/python/ray/tune/tests/ext_pytorch.py b/python/ray/tune/tests/ext_pytorch.py index a602157f1051..32ad3d3e4eeb 100644 --- a/python/ray/tune/tests/ext_pytorch.py +++ b/python/ray/tune/tests/ext_pytorch.py @@ -424,7 +424,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2): best_trained_model = nn.DataParallel(best_trained_model) best_trained_model.to(device) - best_checkpoint_dir = best_trial.checkpoint.checkpoint_dir_or_data + best_checkpoint_dir = best_trial.checkpoint.dir_or_data model_state, optimizer_state = torch.load(os.path.join( best_checkpoint_dir, "checkpoint")) best_trained_model.load_state_dict(model_state) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index bd82b6216acb..b5d8a838e7b9 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -26,13 +26,13 @@ def checkpoint_manager(self, keep_checkpoints_num): def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) memory_checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data={0}, + dir_or_data={0}, storage_mode=_TrackedCheckpoint.MEMORY, result=self.mock_result(0, 0), ) checkpoint_manager.on_checkpoint(memory_checkpoint) persistent_checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data={1}, + dir_or_data={1}, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(1, 1), ) @@ -50,7 +50,7 @@ def testOnCheckpointOrdered(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ _TrackedCheckpoint( - 
checkpoint_dir_or_data={i}, + dir_or_data={i}, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(i, i), ) @@ -82,7 +82,7 @@ def testOnCheckpointUnordered(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ _TrackedCheckpoint( - checkpoint_dir_or_data={i}, + dir_or_data={i}, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(i, i), ) @@ -116,7 +116,7 @@ def testBestCheckpoints(self): keep_checkpoints_num = 4 checkpoints = [ _TrackedCheckpoint( - checkpoint_dir_or_data=i, + dir_or_data=i, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(i, i), ) @@ -132,7 +132,7 @@ def testBestCheckpoints(self): best_checkpoints = checkpoint_manager.best_checkpoints() self.assertEqual(len(best_checkpoints), keep_checkpoints_num) for i in range(len(best_checkpoints)): - self.assertEqual(best_checkpoints[i].checkpoint_dir_or_data, i + 4) + self.assertEqual(best_checkpoints[i].dir_or_data, i + 4) def testBestCheckpointsWithNan(self): """ @@ -141,14 +141,14 @@ def testBestCheckpointsWithNan(self): keep_checkpoints_num = 2 checkpoints = [ _TrackedCheckpoint( - checkpoint_dir_or_data=None, + dir_or_data=None, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(float("nan"), i), ) for i in range(2) ] + [ _TrackedCheckpoint( - checkpoint_dir_or_data=3, + dir_or_data=3, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(0, 3), ) @@ -162,8 +162,8 @@ def testBestCheckpointsWithNan(self): best_checkpoints = checkpoint_manager.best_checkpoints() # best_checkpoints is sorted from worst to best self.assertEqual(len(best_checkpoints), keep_checkpoints_num) - self.assertEqual(best_checkpoints[0].checkpoint_dir_or_data, None) - self.assertEqual(best_checkpoints[1].checkpoint_dir_or_data, 3) + self.assertEqual(best_checkpoints[0].dir_or_data, None) + self.assertEqual(best_checkpoints[1].dir_or_data, 3) def testBestCheckpointsOnlyNan(self): """ @@ -173,7 +173,7 @@ def testBestCheckpointsOnlyNan(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ _TrackedCheckpoint( - checkpoint_dir_or_data=i, + dir_or_data=i, storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(float("nan"), i), ) @@ -186,8 +186,8 @@ def testBestCheckpointsOnlyNan(self): best_checkpoints = checkpoint_manager.best_checkpoints() # best_checkpoints is sorted from worst to best self.assertEqual(len(best_checkpoints), keep_checkpoints_num) - self.assertEqual(best_checkpoints[0].checkpoint_dir_or_data, 2) - self.assertEqual(best_checkpoints[1].checkpoint_dir_or_data, 3) + self.assertEqual(best_checkpoints[0].dir_or_data, 2) + self.assertEqual(best_checkpoints[1].dir_or_data, 3) def testOnCheckpointUnavailableAttribute(self): """ @@ -197,7 +197,7 @@ def testOnCheckpointUnavailableAttribute(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) no_attr_checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data=0, + dir_or_data=0, storage_mode=_TrackedCheckpoint.PERSISTENT, result={}, ) @@ -213,12 +213,12 @@ def testOnCheckpointUnavailableAttribute(self): def testOnMemoryCheckpoint(self): checkpoints = [ _TrackedCheckpoint( - checkpoint_dir_or_data=0, + dir_or_data=0, storage_mode=_TrackedCheckpoint.MEMORY, result=self.mock_result(0, 0), ), _TrackedCheckpoint( - checkpoint_dir_or_data=0, + dir_or_data=0, storage_mode=_TrackedCheckpoint.MEMORY, result=self.mock_result(0, 0), ), @@ -235,7 +235,7 @@ def testSameCheckpoint(self): checkpoint_manager = CheckpointManager( 
keep_checkpoints_num=1, checkpoint_score_attr="i", - delete_fn=lambda c: os.remove(c.checkpoint_dir_or_data), + delete_fn=lambda c: os.remove(c.dir_or_data), ) tmpfiles = [] @@ -247,29 +247,29 @@ def testSameCheckpoint(self): checkpoints = [ _TrackedCheckpoint( - checkpoint_dir_or_data=tmpfiles[0], + dir_or_data=tmpfiles[0], storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(5, 5), ), _TrackedCheckpoint( - checkpoint_dir_or_data=tmpfiles[1], + dir_or_data=tmpfiles[1], storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(10, 10), ), _TrackedCheckpoint( - checkpoint_dir_or_data=tmpfiles[2], + dir_or_data=tmpfiles[2], storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(0, 0), ), _TrackedCheckpoint( - checkpoint_dir_or_data=tmpfiles[1], + dir_or_data=tmpfiles[1], storage_mode=_TrackedCheckpoint.PERSISTENT, result=self.mock_result(20, 20), ), ] for checkpoint in checkpoints: checkpoint_manager.on_checkpoint(checkpoint) - self.assertTrue(os.path.exists(checkpoint.checkpoint_dir_or_data)) + self.assertTrue(os.path.exists(checkpoint.dir_or_data)) for tmpfile in tmpfiles: if os.path.exists(tmpfile): diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 5f5d387f0a11..5d787779a3af 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -394,7 +394,7 @@ def hidden_path_func(checkpoint_path): cluster.add_node(num_cpus=1) cluster.remove_node(node) cluster.wait_for_nodes() - shutil.rmtree(os.path.dirname(t1.checkpoint.checkpoint_dir_or_data)) + shutil.rmtree(os.path.dirname(t1.checkpoint.dir_or_data)) while not runner.is_finished(): runner.step() assert t1.status == Trial.TERMINATED, runner.debug_string() diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 84fcde3006a6..d0a9c0eab59e 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -310,7 +310,7 @@ def train(config): self.assertEqual(ea.best_config, trials[2].config) self.assertEqual(ea.best_logdir, trials[2].logdir) self.assertEqual( - ea.best_checkpoint._local_path, trials[2].checkpoint.checkpoint_dir_or_data + ea.best_checkpoint._local_path, trials[2].checkpoint.dir_or_data ) self.assertTrue(all(ea.best_dataframe["trial_id"] == trials[2].trial_id)) self.assertEqual(ea.results_df.loc[trials[2].trial_id, "res"], 309) diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 230ba0f6be00..087d31d4692f 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -321,9 +321,7 @@ def train(config, checkpoint_dir=False): f.write("hello") [trial] = tune.run(train).trials - assert os.path.exists( - os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log") - ) + assert os.path.exists(os.path.join(trial.checkpoint.dir_or_data, "ckpt.log")) def testCheckpointFunctionAtEndContext(self): def train(config, checkpoint_dir=False): @@ -335,9 +333,7 @@ def train(config, checkpoint_dir=False): f.write("hello") [trial] = tune.run(train).trials - assert os.path.exists( - os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log") - ) + assert os.path.exists(os.path.join(trial.checkpoint.dir_or_data, "ckpt.log")) def testVariousCheckpointFunctionAtEnd(self): def train(config, checkpoint_dir=False): @@ -353,9 +349,7 @@ def train(config, checkpoint_dir=False): f.write("goodbye") [trial] = 
tune.run(train, keep_checkpoints_num=3).trials - assert os.path.exists( - os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log2") - ) + assert os.path.exists(os.path.join(trial.checkpoint.dir_or_data, "ckpt.log2")) def testReuseCheckpoint(self): def train(config, checkpoint_dir=None): @@ -375,10 +369,8 @@ def train(config, checkpoint_dir=None): train, config={"max_iter": 5}, ).trials - last_ckpt = trial.checkpoint.checkpoint_dir_or_data - assert os.path.exists( - os.path.join(trial.checkpoint.checkpoint_dir_or_data, "ckpt.log") - ) + last_ckpt = trial.checkpoint.dir_or_data + assert os.path.exists(os.path.join(trial.checkpoint.dir_or_data, "ckpt.log")) analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 5 @@ -401,7 +393,7 @@ def train(config, checkpoint_dir=None): tune.report(test=i, training_iteration=i) analysis = tune.run(train, max_failures=3) - last_ckpt = analysis.trials[0].checkpoint.checkpoint_dir_or_data + last_ckpt = analysis.trials[0].checkpoint.dir_or_data assert os.path.exists(os.path.join(last_ckpt, "ckpt.log")) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 10 diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index 3e69793cbca9..4c064f0e6af1 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -122,7 +122,7 @@ def _simulate_getting_result(self, trial): def _simulate_saving(self, trial): checkpoint = self.trial_executor.save(trial, _TrackedCheckpoint.PERSISTENT) self.assertEqual(checkpoint, trial.saving_to) - self.assertEqual(trial.checkpoint.checkpoint_dir_or_data, None) + self.assertEqual(trial.checkpoint.dir_or_data, None) event = self.trial_executor.get_next_executor_event( live_trials={trial}, next_trial_exists=False ) @@ -376,7 +376,7 @@ def generate_trials(spec, name): def process_trial_save(self, trial, checkpoint_value): """Simulates trial runner save.""" checkpoint = trial.saving_to - checkpoint.checkpoint_dir_or_data = checkpoint_value + checkpoint.dir_or_data = checkpoint_value trial.on_checkpoint(checkpoint) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 368fc96981df..aae34ea076c7 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -211,7 +211,7 @@ def testCheckpointing(self): self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) runner.step() # Process result, dispatch save runner.step() # Process save, stop trial - kwargs["restore_path"] = trials[0].checkpoint.checkpoint_dir_or_data + kwargs["restore_path"] = trials[0].checkpoint.dir_or_data self.assertEqual(trials[0].status, Trial.TERMINATED) runner.add_trial(Trial("__fake", **kwargs)) @@ -224,7 +224,7 @@ def testCheckpointing(self): self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) - self.addCleanup(os.remove, trials[0].checkpoint.checkpoint_dir_or_data) + self.addCleanup(os.remove, trials[0].checkpoint.dir_or_data) def testRestoreMetricsAfterCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) @@ -244,7 +244,7 @@ def testRestoreMetricsAfterCheckpointing(self): self.assertEqual(trials[0].status, Trial.TERMINATED) - kwargs["restore_path"] = 
trials[0].checkpoint.checkpoint_dir_or_data + kwargs["restore_path"] = trials[0].checkpoint.dir_or_data kwargs.pop("stopping_criterion") kwargs.pop("checkpoint_freq") # No checkpointing for next trial runner.add_trial(Trial("__fake", **kwargs)) @@ -263,7 +263,7 @@ def testRestoreMetricsAfterCheckpointing(self): self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20) self.assertEqual(trials[1].last_result["iterations_since_restore"], 2) self.assertGreater(trials[1].last_result["time_since_restore"], 0) - self.addCleanup(os.remove, trials[0].checkpoint.checkpoint_dir_or_data) + self.addCleanup(os.remove, trials[0].checkpoint.dir_or_data) def testCheckpointingAtEnd(self): ray.init(num_cpus=1, num_gpus=1) @@ -323,9 +323,7 @@ def testPauseResumeCheckpointCount(self): trial = Trial("__fake", keep_checkpoints_num=2) trial.init_logdir() - trial.checkpoint_manager.set_delete_fn( - lambda cp: shutil.rmtree(cp.checkpoint_dir_or_data) - ) + trial.checkpoint_manager.set_delete_fn(lambda cp: shutil.rmtree(cp.dir_or_data)) def write_checkpoint(trial: Trial, index: int): checkpoint_dir = TrainableUtil.make_checkpoint_dir( @@ -336,7 +334,7 @@ def write_checkpoint(trial: Trial, index: int): json.dump(result, f) tune_cp = _TrackedCheckpoint( - checkpoint_dir_or_data=checkpoint_dir, + dir_or_data=checkpoint_dir, storage_mode=_TrackedCheckpoint.PERSISTENT, result=result, ) @@ -380,9 +378,7 @@ def get_checkpoint_dirs(trial: Trial): runner.resume() trial = runner.get_trials()[0] - trial.checkpoint_manager.set_delete_fn( - lambda cp: shutil.rmtree(cp.checkpoint_dir_or_data) - ) + trial.checkpoint_manager.set_delete_fn(lambda cp: shutil.rmtree(cp.dir_or_data)) # Write fourth checkpoint result = write_checkpoint(trial, 4) diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index c9414f5cbeec..986f409fa67a 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -151,7 +151,7 @@ def testCallbackSteps(self): # Just a placeholder object ref for cp.value. 
cp = _TrackedCheckpoint( - checkpoint_dir_or_data=ray.put(1), + dir_or_data=ray.put(1), storage_mode=_TrackedCheckpoint.PERSISTENT, result={TRAINING_ITERATION: 0}, ) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index a7e2ffb6a6f4..5b7be548910d 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -251,7 +251,7 @@ def restore(self, trial, checkpoint=None, block=False): def save(self, trial, type=_TrackedCheckpoint.PERSISTENT, result=None): return _TrackedCheckpoint( - checkpoint_dir_or_data=trial.trainable_name, + dir_or_data=trial.trainable_name, storage_mode=_TrackedCheckpoint.PERSISTENT, result=result, ) @@ -843,12 +843,12 @@ def __init__(self, i, config): self._default_result_or_future = None def on_checkpoint(self, checkpoint): - self.restored_checkpoint = checkpoint.checkpoint_dir_or_data + self.restored_checkpoint = checkpoint.dir_or_data @property def checkpoint(self): return _TrackedCheckpoint( - checkpoint_dir_or_data=self.trainable_name, + dir_or_data=self.trainable_name, storage_mode=_TrackedCheckpoint.MEMORY, result=None, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index e53c8150e310..8f25b214b023 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -441,7 +441,7 @@ class MockTrial(Trial): @property def checkpoint(self): return _TrackedCheckpoint( - checkpoint_dir_or_data="None", + dir_or_data="None", storage_mode=_TrackedCheckpoint.MEMORY, result={}, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py index ed56bb3ab815..035069b8fce3 100644 --- a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py +++ b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py @@ -50,7 +50,7 @@ class MockTrial(Trial): @property def checkpoint(self): return _TrackedCheckpoint( - checkpoint_dir_or_data="None", + dir_or_data="None", storage_mode=_TrackedCheckpoint.MEMORY, result={}, ) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index dab0e7c21ccb..633b59b5a546 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -111,9 +111,9 @@ def __call__(self, checkpoint: _TrackedCheckpoint): if ( checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT - and checkpoint.checkpoint_dir_or_data + and checkpoint.dir_or_data ): - checkpoint_path = checkpoint.checkpoint_dir_or_data + checkpoint_path = checkpoint.dir_or_data logger.debug( "Trial %s: Deleting checkpoint %s", self.trial_id, checkpoint_path @@ -466,9 +466,9 @@ def checkpoint(self): checkpoint = self.checkpoint_manager.newest_persistent_checkpoint else: checkpoint = self.checkpoint_manager.newest_checkpoint - if checkpoint.checkpoint_dir_or_data is None: + if checkpoint.dir_or_data is None: checkpoint = _TrackedCheckpoint( - checkpoint_dir_or_data=self.restore_path, + dir_or_data=self.restore_path, storage_mode=_TrackedCheckpoint.PERSISTENT, ) return checkpoint @@ -648,10 +648,10 @@ def should_checkpoint(self): ) def has_checkpoint(self): - return self.checkpoint.checkpoint_dir_or_data is not None + return self.checkpoint.dir_or_data is not None def clear_checkpoint(self): - self.checkpoint.checkpoint_dir_or_data = None + self.checkpoint.dir_or_data = None self.restoring_from = None self.invalidate_json_state() diff --git 
a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 04b392fbe72d..65e70085d475 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -1109,7 +1109,7 @@ def _process_trial_save( logger.debug("Trial %s: Processing trial save.", trial) try: - trial.saving_to.checkpoint_dir_or_data = checkpoint_value + trial.saving_to.dir_or_data = checkpoint_value self._callbacks.on_checkpoint( iteration=self._iteration, trials=self._trials, diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 85943de3bacc..c8c4ed1c5ce6 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -26,17 +26,17 @@ class _TrackedCheckpoint: def __init__( self, - checkpoint_dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], + dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], storage_mode: str, checkpoint_id: Optional[int] = None, result: Optional[Dict] = None, node_ip: Optional[str] = None, ): - self.checkpoint_dir_or_data = checkpoint_dir_or_data - self.checkpoint_id = checkpoint_id + self.dir_or_data = dir_or_data + self.id = checkpoint_id self.storage_mode = storage_mode - # Todo: What to do if result is a subset of checkpoint_dir_or_data (dict) + # Todo: What to do if result is a subset of dir_or_data (dict) self.result = result or {} self.node_ip = node_ip or self.result.get(NODE_IP, None) @@ -57,7 +57,7 @@ def __repr__(self): return ( f"<_TrackedCheckpoint storage='PERSISTENT' " - f"checkpoint_dir_or_data={self.checkpoint_dir_or_data}>" + f"dir_or_data={self.dir_or_data}>" ) @@ -65,16 +65,16 @@ def _default_delete_fn(checkpoint: _TrackedCheckpoint): if checkpoint.storage_mode != _TrackedCheckpoint.PERSISTENT: return - if isinstance(checkpoint.checkpoint_dir_or_data, (str, bytes, os.PathLike)): - if os.path.isfile(checkpoint.checkpoint_dir_or_data): - os.remove(checkpoint.checkpoint_dir_or_data) + if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): + if os.path.isfile(checkpoint.dir_or_data): + os.remove(checkpoint.dir_or_data) return - elif os.path.isdir(checkpoint.checkpoint_dir_or_data): - shutil.rmtree(checkpoint.checkpoint_dir_or_data) + elif os.path.isdir(checkpoint.dir_or_data): + shutil.rmtree(checkpoint.dir_or_data) return logger.warning( f"Could not delete checkpoint {checkpoint} from disk as it is " - f"neither file not directory. Path: {checkpoint.checkpoint_dir_or_data}." + f"neither file not directory. Path: {checkpoint.dir_or_data}." ) @@ -136,6 +136,18 @@ def __post_init__(self): class CheckpointManager: + """Common checkpoint management and bookkeeping class for Ray Train and Tune. + + This class acts as the common core for checkpoint bookkeeping in Ray ML libraries. + On a high level, this manager keeps a reference to all stored checkpoints + (both in-memory and on-disk checkpoints). For on-disk checkpoints, it + keeps a configured number of checkpoints according to specified metrics. + + The manager supports lazy data writing by utilizing the + ``_TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint + should be persisted to disk. 
+ """ + def __init__( self, checkpoint_strategy: CheckpointStrategy, @@ -229,7 +241,7 @@ def _get_checkpoint_score( return ( not is_nan(checkpoint_score), checkpoint_score if not is_nan(checkpoint_score) else 0, - checkpoint.checkpoint_id, + checkpoint.id, ) def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): @@ -258,10 +270,7 @@ def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): ).tracked_checkpoint # Only remove if checkpoint data is different - if ( - worst_checkpoint.checkpoint_dir_or_data - != checkpoint.checkpoint_dir_or_data - ): + if worst_checkpoint.dir_or_data != checkpoint.dir_or_data: self._maybe_delete_persisted_checkpoint(worst_checkpoint) logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") @@ -296,6 +305,9 @@ def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): def _get_next_checkpoint_path(self) -> Optional[Path]: return None + def __del__(self): + self._cleanup_checkpoints() + def __getstate__(self): state = self.__dict__.copy() @@ -304,7 +316,7 @@ def __getstate__(self): # Avoid serializing the memory checkpoint. state["_newest_memory_checkpoint"] = _TrackedCheckpoint( - checkpoint_dir_or_data=None, + dir_or_data=None, checkpoint_id=0, storage_mode=_TrackedCheckpoint.MEMORY, ) diff --git a/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py b/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py index 8e8a0ab88552..fdb1ab856f12 100644 --- a/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py +++ b/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py @@ -86,7 +86,7 @@ def plot_model_weights(means, covs, ax): # Restore trainer from checkpoint trial = analysis.trials[0] trainer = BanditLinTSTrainer(config=config) - trainer.restore(trial.checkpoint.checkpoint_dir_or_data) + trainer.restore(trial.checkpoint.dir_or_data) # Get model to plot arm weights distribution model = trainer.get_policy().model From e566c8f6107c338e5f1aec04bf414d7ef022f751 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 6 May 2022 12:17:39 +0100 Subject: [PATCH 26/81] Update python/ray/train/checkpoint.py Co-authored-by: Antoni Baum --- python/ray/train/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 896c6f2db8d1..bc0fbd8c6714 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -51,7 +51,7 @@ def __init__(self, *args, **kwargs): self.dir_or_data = None @property - def committed(self): + def committed(self) -> bool: return not self._data_to_commit def commit(self, path: Optional[Path] = None): From 3b727638368e82658a2de3cd74f021f02f461eec Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 9 May 2022 10:20:06 +0200 Subject: [PATCH 27/81] Default checkpoint score attr to None --- python/ray/train/checkpoint.py | 14 ++++++++------ python/ray/tune/checkpoint_manager.py | 5 ++++- python/ray/util/ml_utils/checkpoint_manager.py | 7 ++++--- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index bc0fbd8c6714..5ca27487adef 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -12,7 +12,6 @@ ) from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.tune.result import TRAINING_ITERATION as TUNE_TRAINING_ITERATION from ray.util.ml_utils.checkpoint_manager import ( CheckpointManager as CommonCheckpointManager, _TrackedCheckpoint, @@ 
-122,6 +121,12 @@ def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): super().__init__(checkpoint_strategy=checkpoint_strategy) + self._validate_checkpoint_strategy() + + def _validate_checkpoint_strategy(self): + if self._checkpoint_strategy.checkpoint_score_attribute is None: + self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP + def _load_checkpoint( self, checkpoint_to_load: Optional[Union[Dict, str, Path]] ) -> Optional[Dict]: @@ -191,13 +196,10 @@ def on_start_training( latest_checkpoint_id: Optional[int] = 0, ): checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + self._checkpoint_strategy = checkpoint_strategy - # We only want to support one CheckpointStrategy object. Thus, - # for Ray Train we update the default score attribute for Ray Train - if checkpoint_strategy.checkpoint_score_attribute == TUNE_TRAINING_ITERATION: - checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP + self._validate_checkpoint_strategy() - self._checkpoint_strategy = checkpoint_strategy self.run_dir = run_dir self._latest_checkpoint_id = latest_checkpoint_id or 0 diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index d5cde13a0ac2..577ada4d3a83 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -2,6 +2,7 @@ import logging from typing import Callable, Optional +from ray.tune.result import TRAINING_ITERATION from ray.util.ml_utils.checkpoint_manager import ( CheckpointStrategy, MIN, @@ -30,7 +31,7 @@ class CheckpointManager(CommonCheckpointManager): def __init__( self, keep_checkpoints_num: int, - checkpoint_score_attr: str, + checkpoint_score_attr: Optional[str], delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, ): if keep_checkpoints_num == 0: @@ -39,6 +40,8 @@ def __init__( "to be None or a number greater than 0" ) + checkpoint_score_attr = checkpoint_score_attr or TRAINING_ITERATION + checkpoint_score_desc = checkpoint_score_attr.startswith("min-") if checkpoint_score_desc: checkpoint_score_attr = checkpoint_score_attr[4:] diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index c8c4ed1c5ce6..4e163f2151e4 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -10,7 +10,7 @@ from typing import Optional, Dict, Union, Callable, Tuple, List, Any import ray -from ray.tune.result import NODE_IP, TRAINING_ITERATION +from ray.tune.result import NODE_IP from ray.util import PublicAPI from ray.util.ml_utils.util import is_nan @@ -110,7 +110,8 @@ class CheckpointStrategy: score checkpoints to determine which checkpoints should be kept on disk when there are greater than ``num_to_keep`` checkpoints. This attribute must be a key from the checkpoint - dictionary which has a numerical value. + dictionary which has a numerical value. Per default, the last + checkpoints will be kept. checkpoint_score_order (str). Either "max" or "min". If "max", then checkpoints with highest values of ``checkpoint_score_attribute`` will be kept. 
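The combined effect of the changes above can be sketched briefly: with checkpoint_score_attribute now defaulting to None, callers either pass an explicit metric or rely on the per-library fallback (Ray Train substitutes TIMESTAMP via _validate_checkpoint_strategy, Ray Tune substitutes TRAINING_ITERATION). A minimal, illustrative configuration, assuming the module path touched by these patches and a made-up "val_loss" metric name:

    from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy

    # Keep only the two best checkpoints, ranked by a hypothetical "val_loss"
    # metric where lower is better.
    strategy = CheckpointStrategy(
        num_to_keep=2,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    )

    # Leaving the score attribute as None defers the choice to the library:
    # Ray Train falls back to TIMESTAMP, Ray Tune to TRAINING_ITERATION.
    recency_strategy = CheckpointStrategy(num_to_keep=2)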
@@ -119,7 +120,7 @@ class CheckpointStrategy: """ num_to_keep: Optional[int] = None - checkpoint_score_attribute: str = TRAINING_ITERATION + checkpoint_score_attribute: Optional[str] = None checkpoint_score_order: str = MAX def __post_init__(self): From a736f07ae71c3985cf122de356ceac0322d4dfaf Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 9 May 2022 10:25:49 +0200 Subject: [PATCH 28/81] _TrackedCheckpoint -> TrackedCheckpoint --- python/ray/train/checkpoint.py | 16 +++-- python/ray/tune/callback.py | 4 +- python/ray/tune/checkpoint_manager.py | 22 +++---- python/ray/tune/ray_trial_executor.py | 14 ++--- python/ray/tune/schedulers/pbt.py | 6 +- python/ray/tune/syncer.py | 8 +-- .../ray/tune/tests/test_checkpoint_manager.py | 62 +++++++++--------- .../ray/tune/tests/test_ray_trial_executor.py | 6 +- python/ray/tune/tests/test_trial_runner_2.py | 6 +- .../tune/tests/test_trial_runner_callbacks.py | 6 +- python/ray/tune/tests/test_trial_scheduler.py | 14 ++--- .../tune/tests/test_trial_scheduler_pbt.py | 6 +- .../test_trial_scheduler_resource_changing.py | 6 +- python/ray/tune/trial.py | 12 ++-- python/ray/tune/trial_executor.py | 8 +-- python/ray/tune/trial_runner.py | 6 +- .../ray/util/ml_utils/checkpoint_manager.py | 63 ++++++++++--------- 17 files changed, 134 insertions(+), 131 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 5ca27487adef..ec89e0ad6076 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -14,7 +14,7 @@ from ray.train.utils import construct_path from ray.util.ml_utils.checkpoint_manager import ( CheckpointManager as CommonCheckpointManager, - _TrackedCheckpoint, + TrackedCheckpoint, CheckpointStrategy, ) @@ -35,7 +35,7 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: return cloudpickle.load(f) -class _NotYetPersistedCheckpoint(_TrackedCheckpoint): +class _NotYetPersistedCheckpoint(TrackedCheckpoint): """Tracked checkpoint that is not yet persisted to disk. This checkpoint class supports lazy writing. The checkpoint manager will @@ -69,18 +69,16 @@ def commit(self, path: Optional[Path] = None): self.dir_or_data = path self._data_to_commit = None - def delete( - self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None - ): + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): if not self.committed: return return super().delete(delete_fn=delete_fn) @classmethod - def from_tracked_checkpoint(cls, checkpoint: _TrackedCheckpoint): + def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): new_checkpoint = cls( dir_or_data=checkpoint.dir_or_data, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, checkpoint_id=checkpoint.id, result=checkpoint.result, node_ip=checkpoint.node_ip, @@ -165,10 +163,10 @@ def _process_checkpoint( f"train.save_checkpoint." 
) - tracked_checkpoint = _TrackedCheckpoint( + tracked_checkpoint = TrackedCheckpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index 7b2989c03e99..3539a493c300 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -3,7 +3,7 @@ import warnings from ray.util.annotations import PublicAPI -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint if TYPE_CHECKING: from ray.tune.trial import Trial @@ -245,7 +245,7 @@ def on_checkpoint( iteration: int, trials: List["Trial"], trial: "Trial", - checkpoint: _TrackedCheckpoint, + checkpoint: TrackedCheckpoint, **info, ): """Called after a trial saved a checkpoint with Tune. diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 577ada4d3a83..aa08d9188bc2 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -8,7 +8,7 @@ MIN, MAX, CheckpointManager as CommonCheckpointManager, - _TrackedCheckpoint, + TrackedCheckpoint, ) logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def __init__( self, keep_checkpoints_num: int, checkpoint_score_attr: Optional[str], - delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, + delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, ): if keep_checkpoints_num == 0: raise RuntimeError( @@ -56,23 +56,23 @@ def __init__( super().__init__(checkpoint_strategy=checkpoint_strategy, delete_fn=delete_fn) - def on_checkpoint(self, checkpoint: _TrackedCheckpoint): + def on_checkpoint(self, checkpoint: TrackedCheckpoint): # Set checkpoint ID checkpoint.id = checkpoint.id or self._latest_checkpoint_id self._latest_checkpoint_id += 1 - if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: + if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: self._replace_latest_memory_checkpoint(checkpoint) else: - assert checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT + assert checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT assert ( self._checkpoint_strategy.num_to_keep is None or self._checkpoint_strategy.num_to_keep > 0 ) self._decide_what_to_do_with_checkpoint(checkpoint) - def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): - assert persisted_checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT + def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): + assert persisted_checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT super()._skip_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) # Ray Tune always keeps track of the latest persisted checkpoint. 
# Note that this checkpoint will be deleted once it is not the @@ -85,10 +85,10 @@ def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): @property def newest_persistent_checkpoint(self): - return self._latest_persisted_checkpoint or _TrackedCheckpoint( + return self._latest_persisted_checkpoint or TrackedCheckpoint( dir_or_data=None, checkpoint_id=-1, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, ) @property @@ -102,10 +102,10 @@ def newest_checkpoint(self): @property def newest_memory_checkpoint(self): - return self._latest_memory_checkpoint or _TrackedCheckpoint( + return self._latest_memory_checkpoint or TrackedCheckpoint( dir_or_data=None, checkpoint_id=-1, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, ) def best_checkpoints(self): diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 79ea74e13cc7..cdf24a982c8f 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -38,7 +38,7 @@ from ray.tune.utils.resource_updater import ResourceUpdater from ray.util import log_once from ray.util.annotations import DeveloperAPI -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint from ray.util.placement_group import remove_placement_group, PlacementGroup logger = logging.getLogger(__name__) @@ -672,9 +672,9 @@ def force_reconcilation_on_next_step_end(self) -> None: def save( self, trial: Trial, - storage: str = _TrackedCheckpoint.PERSISTENT, + storage: str = TrackedCheckpoint.PERSISTENT, result: Optional[Dict] = None, - ) -> _TrackedCheckpoint: + ) -> TrackedCheckpoint: """Saves the trial's state to a checkpoint asynchronously. Args: @@ -690,15 +690,15 @@ def save( logger.debug(f"saving trial {trial}") result = result or trial.last_result with self._change_working_directory(trial): - if storage == _TrackedCheckpoint.MEMORY: + if storage == TrackedCheckpoint.MEMORY: value = trial.runner.save_to_object.remote() - checkpoint = _TrackedCheckpoint( + checkpoint = TrackedCheckpoint( dir_or_data=value, storage_mode=storage, result=result ) trial.on_checkpoint(checkpoint) else: value = trial.runner.save.remote() - checkpoint = _TrackedCheckpoint( + checkpoint = TrackedCheckpoint( dir_or_data=value, storage_mode=storage, result=result ) trial.saving_to = checkpoint @@ -725,7 +725,7 @@ def restore(self, trial: Trial) -> None: ) value = checkpoint.dir_or_data node_ip = checkpoint.node_ip - if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: + if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: logger.debug("Trial %s: Attempting restore from object", trial) # Note that we don't store the remote since in-memory checkpoints # don't guarantee fault tolerance and don't need to be waited on. 
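For reference, the renamed class keeps the constructor used throughout the diffs above (dir_or_data, storage_mode, result, plus optional checkpoint_id and node_ip). A minimal sketch, assuming the module path from this patch series; the payload, path, and metric values are invented for illustration:

    from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint

    # An in-memory checkpoint; dir_or_data may hold a dict payload or an
    # object reference produced by the trainable.
    memory_checkpoint = TrackedCheckpoint(
        dir_or_data={"weights": [0.0, 1.0]},
        storage_mode=TrackedCheckpoint.MEMORY,
        result={"training_iteration": 1},
    )

    # A persistent checkpoint pointing at a (hypothetical) directory on disk.
    persistent_checkpoint = TrackedCheckpoint(
        dir_or_data="/tmp/trial_0/checkpoint_000001",
        storage_mode=TrackedCheckpoint.PERSISTENT,
        result={"training_iteration": 1},
    )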
diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index e341a2b8ea31..ba1cda338296 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -18,7 +18,7 @@ from ray.tune.suggest.variant_generator import format_vars from ray.tune.trial import Trial from ray.util.debug import log_once -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint logger = logging.getLogger(__name__) @@ -529,7 +529,7 @@ def _checkpoint_or_exploit( state.last_checkpoint = trial.checkpoint else: state.last_checkpoint = trial_executor.save( - trial, _TrackedCheckpoint.MEMORY, result=state.last_result + trial, TrackedCheckpoint.MEMORY, result=state.last_result ) self._num_checkpoints += 1 else: @@ -873,7 +873,7 @@ def on_trial_result( ) checkpoint = trial_runner.trial_executor.save( - trial, _TrackedCheckpoint.MEMORY, result=result + trial, TrackedCheckpoint.MEMORY, result=result ) new_tag = make_experiment_tag(self.experiment_tag, new_config, new_config) diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index 4a117d1af72f..e97116036b4f 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -37,7 +37,7 @@ RemoteTaskClient, ) from ray.util.annotations import PublicAPI -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint if TYPE_CHECKING: from ray.tune.trial import Trial @@ -525,8 +525,8 @@ def _create_trial_syncer(self, trial: "Trial"): def _remove_trial_syncer(self, trial: "Trial"): self._syncers.pop(trial, None) - def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint): - if checkpoint.storage_mode == _TrackedCheckpoint.MEMORY: + def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: TrackedCheckpoint): + if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: return trial_syncer = self._get_trial_syncer(trial) @@ -617,7 +617,7 @@ def on_checkpoint( iteration: int, trials: List["Trial"], trial: "Trial", - checkpoint: _TrackedCheckpoint, + checkpoint: TrackedCheckpoint, **info, ): self._sync_trial_checkpoint(trial, checkpoint) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index b5d8a838e7b9..cb7d414a4056 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -8,7 +8,7 @@ from ray.tune.result import TRAINING_ITERATION from ray.tune.checkpoint_manager import CheckpointManager -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, logger +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, logger class CheckpointManagerTest(unittest.TestCase): @@ -25,15 +25,15 @@ def checkpoint_manager(self, keep_checkpoints_num): def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - memory_checkpoint = _TrackedCheckpoint( + memory_checkpoint = TrackedCheckpoint( dir_or_data={0}, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result=self.mock_result(0, 0), ) checkpoint_manager.on_checkpoint(memory_checkpoint) - persistent_checkpoint = _TrackedCheckpoint( + persistent_checkpoint = TrackedCheckpoint( dir_or_data={1}, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(1, 1), ) checkpoint_manager.on_checkpoint(persistent_checkpoint) @@ -49,9 
+49,9 @@ def testOnCheckpointOrdered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data={i}, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(i, i), ) for i in range(3) @@ -81,9 +81,9 @@ def testOnCheckpointUnordered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data={i}, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(i, i), ) for i in range(3, -1, -1) @@ -115,9 +115,9 @@ def testBestCheckpoints(self): """ keep_checkpoints_num = 4 checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=i, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(i, i), ) for i in range(8) @@ -140,16 +140,16 @@ def testBestCheckpointsWithNan(self): """ keep_checkpoints_num = 2 checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=None, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(float("nan"), i), ) for i in range(2) ] + [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=3, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(0, 3), ) ] @@ -172,9 +172,9 @@ def testBestCheckpointsOnlyNan(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=i, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(float("nan"), i), ) for i in range(4) @@ -196,9 +196,9 @@ def testOnCheckpointUnavailableAttribute(self): """ checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - no_attr_checkpoint = _TrackedCheckpoint( + no_attr_checkpoint = TrackedCheckpoint( dir_or_data=0, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result={}, ) @@ -212,14 +212,14 @@ def testOnCheckpointUnavailableAttribute(self): def testOnMemoryCheckpoint(self): checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=0, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result=self.mock_result(0, 0), ), - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=0, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result=self.mock_result(0, 0), ), ] @@ -246,24 +246,24 @@ def testSameCheckpoint(self): tmpfiles.append(tmpfile) checkpoints = [ - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=tmpfiles[0], - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(5, 5), ), - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=tmpfiles[1], - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(10, 10), ), - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=tmpfiles[2], - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(0, 0), ), - _TrackedCheckpoint( + TrackedCheckpoint( dir_or_data=tmpfiles[1], - storage_mode=_TrackedCheckpoint.PERSISTENT, + 
storage_mode=TrackedCheckpoint.PERSISTENT, result=self.mock_result(20, 20), ), ] diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index 4c064f0e6af1..12053447e2cf 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -24,7 +24,7 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory from unittest.mock import patch -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint class TrialExecutorInsufficientResourcesTest(unittest.TestCase): @@ -120,7 +120,7 @@ def _simulate_getting_result(self, trial): trial.update_last_result(training_result) def _simulate_saving(self, trial): - checkpoint = self.trial_executor.save(trial, _TrackedCheckpoint.PERSISTENT) + checkpoint = self.trial_executor.save(trial, TrackedCheckpoint.PERSISTENT) self.assertEqual(checkpoint, trial.saving_to) self.assertEqual(trial.checkpoint.dir_or_data, None) event = self.trial_executor.get_next_executor_event( @@ -189,7 +189,7 @@ def testSavePauseResumeErrorRestore(self): # Pause self.trial_executor.pause_trial(trial) self.assertEqual(Trial.PAUSED, trial.status) - self.assertEqual(trial.checkpoint.storage_mode, _TrackedCheckpoint.MEMORY) + self.assertEqual(trial.checkpoint.storage_mode, TrackedCheckpoint.MEMORY) # Resume self._simulate_starting_trial(trial) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index aae34ea076c7..4d69c94e46d0 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -18,7 +18,7 @@ from ray.tune.suggest import BasicVariantGenerator from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver from ray.tune.utils.trainable import TrainableUtil -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint def create_mock_components(): @@ -333,9 +333,9 @@ def write_checkpoint(trial: Trial, index: int): with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f: json.dump(result, f) - tune_cp = _TrackedCheckpoint( + tune_cp = TrackedCheckpoint( dir_or_data=checkpoint_dir, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=result, ) trial.saving_to = tune_cp diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 986f409fa67a..7f9d4b97383f 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -25,7 +25,7 @@ from ray.tune import Callback from ray.tune.utils.callback import create_default_callbacks from ray.tune.experiment import Experiment -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint class TestCallback(Callback): @@ -150,9 +150,9 @@ def testCallbackSteps(self): self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "two") # Just a placeholder object ref for cp.value. 
- cp = _TrackedCheckpoint( + cp = TrackedCheckpoint( dir_or_data=ray.put(1), - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result={TRAINING_ITERATION: 0}, ) trials[0].saving_to = cp diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 5b7be548910d..98e1c93d519c 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -33,7 +33,7 @@ from ray.tune.resources import Resources from ray.rllib import _register_all -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint _register_all() @@ -249,10 +249,10 @@ def stop_trial(self, trial, error=False, error_msg=None): def restore(self, trial, checkpoint=None, block=False): pass - def save(self, trial, type=_TrackedCheckpoint.PERSISTENT, result=None): - return _TrackedCheckpoint( + def save(self, trial, type=TrackedCheckpoint.PERSISTENT, result=None): + return TrackedCheckpoint( dir_or_data=trial.trainable_name, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, result=result, ) @@ -312,7 +312,7 @@ def get_live_trials(self): return {t for t in self.trials if t.status != Trial.TERMINATED} def _pause_trial(self, trial): - self.trial_executor.save(trial, _TrackedCheckpoint.MEMORY, None) + self.trial_executor.save(trial, TrackedCheckpoint.MEMORY, None) trial.status = Trial.PAUSED def _launch_trial(self, trial): @@ -847,9 +847,9 @@ def on_checkpoint(self, checkpoint): @property def checkpoint(self): - return _TrackedCheckpoint( + return TrackedCheckpoint( dir_or_data=self.trainable_name, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result=None, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 8f25b214b023..f96e98cfe803 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -20,7 +20,7 @@ # Import psutil after ray so the packaged version is used. 
import psutil -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint MB = 1024 ** 2 @@ -440,9 +440,9 @@ def testBurnInPeriod(self): class MockTrial(Trial): @property def checkpoint(self): - return _TrackedCheckpoint( + return TrackedCheckpoint( dir_or_data="None", - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result={}, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py index 035069b8fce3..43fedd3ce095 100644 --- a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py +++ b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py @@ -8,7 +8,7 @@ DistributeResources, DistributeResourcesToTopJob, ) -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint class MockResourceUpdater: @@ -49,9 +49,9 @@ def get_trials(self): class MockTrial(Trial): @property def checkpoint(self): - return _TrackedCheckpoint( + return TrackedCheckpoint( dir_or_data="None", - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, result={}, ) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 633b59b5a546..30684269f0f0 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -40,7 +40,7 @@ from ray.tune.utils import date_str, flatten_dict from ray.util.annotations import DeveloperAPI from ray._private.utils import binary_to_hex, hex_to_binary -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint DEBUG_PRINT_INTERVAL = 5 logger = logging.getLogger(__name__) @@ -100,7 +100,7 @@ def __init__(self, trial_id, runner): self.trial_id = trial_id self.runner = runner - def __call__(self, checkpoint: _TrackedCheckpoint): + def __call__(self, checkpoint: TrackedCheckpoint): """Requests checkpoint deletion asynchronously. Args: @@ -110,7 +110,7 @@ def __call__(self, checkpoint: _TrackedCheckpoint): return if ( - checkpoint.storage_mode == _TrackedCheckpoint.PERSISTENT + checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT and checkpoint.dir_or_data ): checkpoint_path = checkpoint.dir_or_data @@ -467,9 +467,9 @@ def checkpoint(self): else: checkpoint = self.checkpoint_manager.newest_checkpoint if checkpoint.dir_or_data is None: - checkpoint = _TrackedCheckpoint( + checkpoint = TrackedCheckpoint( dir_or_data=self.restore_path, - storage_mode=_TrackedCheckpoint.PERSISTENT, + storage_mode=TrackedCheckpoint.PERSISTENT, ) return checkpoint @@ -655,7 +655,7 @@ def clear_checkpoint(self): self.restoring_from = None self.invalidate_json_state() - def on_checkpoint(self, checkpoint: _TrackedCheckpoint): + def on_checkpoint(self, checkpoint: TrackedCheckpoint): """Hook for handling checkpoints taken by the Trainable. 
Args: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index e391864a81dc..34bf96d7c5d6 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -7,7 +7,7 @@ from ray.tune import TuneError from ray.util.annotations import DeveloperAPI from ray.tune.trial import Trial -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint logger = logging.getLogger(__name__) @@ -124,7 +124,7 @@ def pause_trial(self, trial: Trial) -> None: """ assert trial.status == Trial.RUNNING, trial.status try: - self.save(trial, _TrackedCheckpoint.MEMORY) + self.save(trial, TrackedCheckpoint.MEMORY) self.stop_trial(trial) self.set_status(trial, Trial.PAUSED) except Exception: @@ -194,9 +194,9 @@ def restore(self, trial: Trial) -> None: def save( self, trial: Trial, - storage: str = _TrackedCheckpoint.PERSISTENT, + storage: str = TrackedCheckpoint.PERSISTENT, result: Optional[Dict] = None, - ) -> _TrackedCheckpoint: + ) -> TrackedCheckpoint: """Saves training state of this trial to a checkpoint. If result is None, this trial's last result will be used. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 65e70085d475..bd8c09ce3437 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -42,7 +42,7 @@ from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder from ray.tune.web_server import TuneServer from ray.util.debug import log_once -from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint MAX_DEBUG_TRIALS = 20 @@ -1117,7 +1117,7 @@ def _process_trial_save( checkpoint=trial.saving_to, ) trial.on_checkpoint(trial.saving_to) - if trial.checkpoint.storage_mode != _TrackedCheckpoint.MEMORY: + if trial.checkpoint.storage_mode != TrackedCheckpoint.MEMORY: self.trial_executor.mark_trial_to_checkpoint(trial) except Exception: logger.exception( @@ -1204,7 +1204,7 @@ def _checkpoint_trial_if_needed(self, trial, force=False): if trial.should_checkpoint() or force: # Save trial runtime if possible. if trial.runner: - self.trial_executor.save(trial, storage=_TrackedCheckpoint.PERSISTENT) + self.trial_executor.save(trial, storage=TrackedCheckpoint.PERSISTENT) def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]): """Tries to recover trial. diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 4e163f2151e4..5b8845f8b002 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -12,6 +12,7 @@ import ray from ray.tune.result import NODE_IP from ray.util import PublicAPI +from ray.util.annotations import DeveloperAPI from ray.util.ml_utils.util import is_nan MAX = "max" @@ -20,7 +21,15 @@ logger = logging.getLogger(__name__) -class _TrackedCheckpoint: +@DeveloperAPI +class TrackedCheckpoint: + """Checkpoint tracked by a checkpoint manager. + + This class is used to track checkpoints generated by trainables and trainers in + order to add metadata (e.g. the result, or the node where it has been created) + and for bookkeeping purposes. 
+ """ + MEMORY = "memory" PERSISTENT = "persistent" @@ -44,25 +53,23 @@ def commit(self, path: Optional[Path] = None): """Commit checkpoint to disk, if needed.""" pass - def delete( - self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None - ): + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): """Delete checkpoint from disk, if needed.""" delete_fn = delete_fn or _default_delete_fn delete_fn(self) def __repr__(self): - if self.storage_mode == _TrackedCheckpoint.MEMORY: - return f"<_TrackedCheckpoint storage='MEMORY' result={self.result}>" + if self.storage_mode == TrackedCheckpoint.MEMORY: + return f"" return ( - f"<_TrackedCheckpoint storage='PERSISTENT' " + f"" ) -def _default_delete_fn(checkpoint: _TrackedCheckpoint): - if checkpoint.storage_mode != _TrackedCheckpoint.PERSISTENT: +def _default_delete_fn(checkpoint: TrackedCheckpoint): + if checkpoint.storage_mode != TrackedCheckpoint.PERSISTENT: return if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): @@ -79,7 +86,7 @@ def _default_delete_fn(checkpoint: _TrackedCheckpoint): class _HeapCheckpointWrapper: - def __init__(self, priority: Any, tracked_checkpoint: _TrackedCheckpoint): + def __init__(self, priority: Any, tracked_checkpoint: TrackedCheckpoint): self.priority = priority self.tracked_checkpoint = tracked_checkpoint @@ -145,7 +152,7 @@ class CheckpointManager: keeps a configured number of checkpoints according to specified metrics. The manager supports lazy data writing by utilizing the - ``_TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint + ``TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint should be persisted to disk. """ @@ -153,7 +160,7 @@ def __init__( self, checkpoint_strategy: CheckpointStrategy, latest_checkpoint_id: int = 0, - delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, + delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, ): self._checkpoint_strategy = checkpoint_strategy @@ -165,27 +172,25 @@ def __init__( # Best checkpoint altogether. # Used for exposing best_checkpoint_path. 
- self._best_persisted_checkpoint: Optional[_TrackedCheckpoint] = None - self._latest_persisted_checkpoint: Optional[_TrackedCheckpoint] = None - self._latest_memory_checkpoint: Optional[_TrackedCheckpoint] = None + self._best_persisted_checkpoint: Optional[TrackedCheckpoint] = None + self._latest_persisted_checkpoint: Optional[TrackedCheckpoint] = None + self._latest_memory_checkpoint: Optional[TrackedCheckpoint] = None # Checkpoints that are not immediately removed self._checkpoints_to_clean_up = set() self._delete_fn = delete_fn - def set_delete_fn( - self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] - ): + def set_delete_fn(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]]): self._delete_fn = delete_fn - def _replace_latest_memory_checkpoint(self, memory_checkpoint: _TrackedCheckpoint): - assert memory_checkpoint.storage_mode == _TrackedCheckpoint.MEMORY + def _replace_latest_memory_checkpoint(self, memory_checkpoint: TrackedCheckpoint): + assert memory_checkpoint.storage_mode == TrackedCheckpoint.MEMORY self._latest_memory_checkpoint = memory_checkpoint # Avoid memory leaks on k8s pods gc.collect() def _replace_latest_persisted_checkpoint( - self, persisted_checkpoint: _TrackedCheckpoint + self, persisted_checkpoint: TrackedCheckpoint ): second_to_latest_persisted_checkpoint = self._latest_persisted_checkpoint self._latest_persisted_checkpoint = persisted_checkpoint @@ -196,7 +201,7 @@ def _replace_latest_persisted_checkpoint( ) def _maybe_replace_best_persisted_checkpoint( - self, persisted_checkpoint: _TrackedCheckpoint + self, persisted_checkpoint: TrackedCheckpoint ): if self._best_persisted_checkpoint is None: self._best_persisted_checkpoint = persisted_checkpoint @@ -207,7 +212,7 @@ def _maybe_replace_best_persisted_checkpoint( self._best_persisted_checkpoint = persisted_checkpoint def _get_checkpoint_score( - self, checkpoint: _TrackedCheckpoint + self, checkpoint: TrackedCheckpoint ) -> Tuple[bool, numbers.Number, int]: checkpoint_score_attribute = ( self._checkpoint_strategy.checkpoint_score_attribute @@ -245,7 +250,7 @@ def _get_checkpoint_score( checkpoint.id, ) - def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): checkpoint_score = self._get_checkpoint_score(checkpoint) wrapped_checkpoint = _HeapCheckpointWrapper( priority=checkpoint_score, tracked_checkpoint=checkpoint @@ -284,14 +289,14 @@ def _decide_what_to_do_with_checkpoint(self, checkpoint: _TrackedCheckpoint): self._cleanup_checkpoints() def _maybe_delete_persisted_checkpoint( - self, persisted_checkpoint: _TrackedCheckpoint + self, persisted_checkpoint: TrackedCheckpoint ): if persisted_checkpoint == self._latest_persisted_checkpoint: self._checkpoints_to_clean_up.add(persisted_checkpoint) else: self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) - def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + def _delete_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): persisted_checkpoint.delete(delete_fn=self._delete_fn) self._checkpoints_to_clean_up.discard(persisted_checkpoint) @@ -299,7 +304,7 @@ def _cleanup_checkpoints(self): for checkpoint in list(self._checkpoints_to_clean_up): self._maybe_delete_persisted_checkpoint(persisted_checkpoint=checkpoint) - def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + def _skip_persisted_checkpoint(self, persisted_checkpoint: 
TrackedCheckpoint): logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") self._checkpoints_to_clean_up.add(persisted_checkpoint) @@ -316,10 +321,10 @@ def __getstate__(self): state.pop("_delete_fn", None) # Avoid serializing the memory checkpoint. - state["_newest_memory_checkpoint"] = _TrackedCheckpoint( + state["_newest_memory_checkpoint"] = TrackedCheckpoint( dir_or_data=None, checkpoint_id=0, - storage_mode=_TrackedCheckpoint.MEMORY, + storage_mode=TrackedCheckpoint.MEMORY, ) return state From 3081baeefed7182433cd1b02177ef60fe5903269 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 9 May 2022 10:25:49 +0200 Subject: [PATCH 29/81] _TrackedCheckpoint -> TrackedCheckpoint --- python/ray/util/ml_utils/checkpoint_manager.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 5b8845f8b002..6496e47f618f 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -12,7 +12,6 @@ import ray from ray.tune.result import NODE_IP from ray.util import PublicAPI -from ray.util.annotations import DeveloperAPI from ray.util.ml_utils.util import is_nan MAX = "max" @@ -21,15 +20,7 @@ logger = logging.getLogger(__name__) -@DeveloperAPI class TrackedCheckpoint: - """Checkpoint tracked by a checkpoint manager. - - This class is used to track checkpoints generated by trainables and trainers in - order to add metadata (e.g. the result, or the node where it has been created) - and for bookkeeping purposes. - """ - MEMORY = "memory" PERSISTENT = "persistent" From ed5d68d1b853461fa68c8722d4723c52f3d63286 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 9 May 2022 11:57:09 +0200 Subject: [PATCH 30/81] Adapt changes --- python/ray/util/ml_utils/checkpoint_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 6496e47f618f..e05c14e0a3d9 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -212,7 +212,7 @@ def _get_checkpoint_score( logger.error( f"Result dict has no key: {checkpoint_score_attribute}. " f"checkpoint_score_attr must be set to a key in the " - f"result dict." + f"result dict. Valid keys are: {list(checkpoint.result.keys())}" ) checkpoint_result = float("-inf") else: From 9159bd6cf5cbcac74afeddc90a286d73138340d9 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 9 May 2022 15:46:44 +0200 Subject: [PATCH 31/81] Default checkpoint strategy --- python/ray/util/ml_utils/checkpoint_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index e05c14e0a3d9..72724e075db6 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -153,7 +153,7 @@ def __init__( latest_checkpoint_id: int = 0, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, ): - self._checkpoint_strategy = checkpoint_strategy + self._checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() # Incremental unique checkpoint ID of this run. 
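Illustrative aside (not part of the patch): with the fallback above, passing checkpoint_strategy=None yields the default CheckpointStrategy. A sketch of an explicit configuration; the metric name is a hypothetical example.

    from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy

    # Keep only the two best checkpoints, ranked by a hypothetical
    # "mean_accuracy" value reported with each checkpoint; higher is better.
    strategy = CheckpointStrategy(
        num_to_keep=2,
        checkpoint_score_attribute="mean_accuracy",
        checkpoint_score_order="max",
    )

    # Invalid values are rejected in __post_init__, e.g.:
    #   CheckpointStrategy(num_to_keep=-1)                 # raises ValueError
    #   CheckpointStrategy(checkpoint_score_order="best")  # raises ValueError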
self._latest_checkpoint_id = latest_checkpoint_id From 27e27abb4cfe723312ca98f3eab38cf34f5dd8d7 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 12 May 2022 10:19:57 +0100 Subject: [PATCH 32/81] Error handling for delete fn --- python/ray/util/ml_utils/checkpoint_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 72724e075db6..dc04cad6eb51 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -47,7 +47,10 @@ def commit(self, path: Optional[Path] = None): def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): """Delete checkpoint from disk, if needed.""" delete_fn = delete_fn or _default_delete_fn - delete_fn(self) + try: + delete_fn(self) + except Exception as e: + logger.warning(f"Checkpoint deletion failed: {e}") def __repr__(self): if self.storage_mode == TrackedCheckpoint.MEMORY: @@ -70,7 +73,7 @@ def _default_delete_fn(checkpoint: TrackedCheckpoint): elif os.path.isdir(checkpoint.dir_or_data): shutil.rmtree(checkpoint.dir_or_data) return - logger.warning( + raise RuntimeError( f"Could not delete checkpoint {checkpoint} from disk as it is " f"neither file not directory. Path: {checkpoint.dir_or_data}." ) From fb2d95e677d86ae22b88c1d55301aeadedc9c53b Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:38:45 +0100 Subject: [PATCH 33/81] Main entrypoint --- python/ray/train/checkpoint.py | 9 +-- python/ray/tune/checkpoint_manager.py | 7 ++- .../ray/util/ml_utils/checkpoint_manager.py | 59 +++++++++++++++++-- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index ec89e0ad6076..ffbddb5bd4a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -142,8 +142,7 @@ def _process_checkpoint( checkpoint_results: List[TrainingResult], decode_checkpoint_fn: Callable, ) -> None: - """Perform all processing for a checkpoint.""" - + """Ray Train entrypoint. Perform all processing for a checkpoint.""" # Get checkpoint from first worker. 
checkpoint_data = checkpoint_results[0].data @@ -169,14 +168,16 @@ def _process_checkpoint( storage_mode=TrackedCheckpoint.MEMORY, result={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) + self.register_checkpoint(checkpoint=tracked_checkpoint) + def register_checkpoint(self, checkpoint: TrackedCheckpoint): # Always update the latest memory checkpoint - self._replace_latest_memory_checkpoint(tracked_checkpoint) + self._replace_latest_memory_checkpoint(checkpoint) # Only process further if we consider keeping this checkpoint on disk if self._checkpoint_strategy.num_to_keep != 0: not_yet_persisted_checkpoint = ( - _NotYetPersistedCheckpoint.from_tracked_checkpoint(tracked_checkpoint) + _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) ) self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index aa08d9188bc2..7e04adf4b006 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -56,7 +56,7 @@ def __init__( super().__init__(checkpoint_strategy=checkpoint_strategy, delete_fn=delete_fn) - def on_checkpoint(self, checkpoint: TrackedCheckpoint): + def handle_checkpoint(self, checkpoint: TrackedCheckpoint): # Set checkpoint ID checkpoint.id = checkpoint.id or self._latest_checkpoint_id self._latest_checkpoint_id += 1 @@ -71,6 +71,11 @@ def on_checkpoint(self, checkpoint: TrackedCheckpoint): ) self._decide_what_to_do_with_checkpoint(checkpoint) + def on_checkpoint(self, checkpoint: TrackedCheckpoint): + """Ray Tune's entrypoint""" + # Todo (krfricke): Replace with handle_checkpoint. + self.handle_checkpoint(checkpoint) + def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): assert persisted_checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT super()._skip_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index dc04cad6eb51..f97068b58377 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -12,6 +12,7 @@ import ray from ray.tune.result import NODE_IP from ray.util import PublicAPI +from ray.util.annotations import DeveloperAPI from ray.util.ml_utils.util import is_nan MAX = "max" @@ -20,7 +21,25 @@ logger = logging.getLogger(__name__) +@DeveloperAPI class TrackedCheckpoint: + """Checkpoint tracked by a checkpoint manager. + + This class is used to track checkpoints generated by trainables and trainers in + order to add metadata (e.g. the result, or the node where it has been created) + and for bookkeeping purposes. + + Args: + dir_or_data: Checkpoint directory, checkpoint data, or a future to either. + storage_mode: Either MEMORY or PERSISTENT. + checkpoint_id: Checkpoint number. Usually this should be monotonically + increasing for each tracked checkpoint. + result: Observed metrics for this checkpoint. This is used to determine + the value of the ``checkpoint_score_attr``. + node_ip: IP of the node where the checkpoint was generated. Defaults + to the current node. + """ + MEMORY = "memory" PERSISTENT = "persistent" @@ -40,12 +59,23 @@ def __init__( self.result = result or {} self.node_ip = node_ip or self.result.get(NODE_IP, None) - def commit(self, path: Optional[Path] = None): - """Commit checkpoint to disk, if needed.""" + def commit(self, path: Optional[Path] = None) -> None: + """Commit checkpoint to disk, if needed. 
+ + Args: + path: Path to commit checkpoint to. + """ pass - def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): - """Delete checkpoint from disk, if needed.""" + def delete( + self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None + ) -> None: + """Delete checkpoint from disk, if needed. + + Args: + delete_fn: Function to be called with the tracked checkpoint as an + argument. Defaults to removing the local directory/file. + """ delete_fn = delete_fn or _default_delete_fn try: delete_fn(self) @@ -175,8 +205,29 @@ def __init__( self._delete_fn = delete_fn def set_delete_fn(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]]): + """Update the function called to delete persisted checkpoints. + + Args: + delete_fn: Function that takes a tracked checkpoint as an argument and + deletes it from disk. + """ self._delete_fn = delete_fn + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + """Register new checkpoint and add to bookkeeping. + + This method will register a new checkpoint and add it to the internal + bookkeeping logic. This means the checkpoint manager will decide if + this checkpoint should be kept, and if older or worse performing + checkpoints should be deleted. + + Subclasses have to implement this method. + + Args: + checkpoint: Tracked checkpoint object to add to bookkeeping. + """ + raise NotImplementedError + def _replace_latest_memory_checkpoint(self, memory_checkpoint: TrackedCheckpoint): assert memory_checkpoint.storage_mode == TrackedCheckpoint.MEMORY self._latest_memory_checkpoint = memory_checkpoint From 29ad103b293f6bd36c15e149b2be5cdbbe929cf0 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:34:38 +0100 Subject: [PATCH 34/81] Add general entrypoint --- dashboard/modules/job/job_manager.py | 2 +- dashboard/optional_utils.py | 2 +- python/ray/_private/gcs_pubsub.py | 6 +- python/ray/tests/test_async.py | 2 +- python/ray/tests/test_client_proxy.py | 2 +- .../ray/tune/analysis/experiment_analysis.py | 2 +- python/ray/tune/checkpoint_manager.py | 2 +- python/ray/tune/examples/tf_mnist_example.py | 8 +- python/ray/tune/trial.py | 2 +- python/ray/util/client/server/proxier.py | 2 +- python/ray/util/client/worker.py | 2 +- python/ray/util/dask/scheduler.py | 6 +- .../ray/util/ml_utils/checkpoint_manager.py | 80 +++++++++++++++---- release/ray_release/anyscale_util.py | 4 +- release/ray_release/cluster_manager/full.py | 20 ++--- .../ray_release/cluster_manager/minimal.py | 10 +-- .../ray_release/command_runner/sdk_runner.py | 8 +- rllib/agents/trainer.py | 9 ++- 18 files changed, 112 insertions(+), 57 deletions(-) diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index a38a59e6e220..758a1ca57c92 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -260,7 +260,7 @@ async def run( # at the same time assert len(finished) == 1, "Should have only one coroutine done" [child_process_task] = finished - return_code = child_process_task.result() + return_code = child_process_task.metrics() if return_code == 0: self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED) else: diff --git a/dashboard/optional_utils.py b/dashboard/optional_utils.py index 284cb8ef9c19..59908d9be11e 100644 --- a/dashboard/optional_utils.py +++ b/dashboard/optional_utils.py @@ -208,7 +208,7 @@ async def _cache_handler(*args) -> aiohttp.web.Response: def _update_cache(task): try: - response = task.result() + response = 
task.metrics() except Exception: response = rest_response( success=False, message=traceback.format_exc() diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 12bae7d64dfb..1aca9d986d29 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -246,7 +246,7 @@ def _poll_locked(self, timeout=None) -> None: try: # Use 1s timeout to check for subscriber closing # periodically. - fut.result(timeout=1) + fut.metrics(timeout=1) break except grpc.FutureTimeoutError: # Subscriber has closed. Cancel inflight request and @@ -262,8 +262,8 @@ def _poll_locked(self, timeout=None) -> None: raise if fut.done(): - self._last_batch_size = len(fut.result().pub_messages) - for msg in fut.result().pub_messages: + self._last_batch_size = len(fut.metrics().pub_messages) + for msg in fut.metrics().pub_messages: if msg.channel_type != self._channel: logger.warn(f"Ignoring message from unsubscribed channel {msg}") continue diff --git a/python/ray/tests/test_async.py b/python/ray/tests/test_async.py index ac39fe681ecd..c8f3b9c65300 100644 --- a/python/ray/tests/test_async.py +++ b/python/ray/tests/test_async.py @@ -137,7 +137,7 @@ def test_concurrent_future(ray_start_regular_shared): def cb(fut): nonlocal global_result - global_result = fut.result() + global_result = fut.metrics() fut.add_done_callback(cb) assert global_result == 1 diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index d0d1d2c307d1..5c34bcea173f 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -52,7 +52,7 @@ def test_proxy_manager_lifecycle(shutdown_only): pm.create_specific_server(client) assert pm.start_specific_server(client, JobConfig()) # Channel should be ready and corresponding to an existing server - grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5) + grpc.channel_ready_future(pm.get_channel(client)).metrics(timeout=5) proc = pm._get_server_for_client(client) assert proc.port == free_ports[0], f"Free Ports are: {free_ports}" diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 5d1bd3621f1e..ed0255365040 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -423,7 +423,7 @@ def get_trial_checkpoints_paths( # Support metrics given as paths, e.g. # "info/learner/default_policy/policy_loss". 
return [ - (c.value, unflattened_lookup(metric, c.result)) for c in checkpoints + (c.value, unflattened_lookup(metric, c.metrics)) for c in checkpoints ] else: raise ValueError("trial should be a string or a Trial instance.") diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 75cf4b8cb835..321f1a132d39 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -214,7 +214,7 @@ def best_checkpoints(self): return [queue_item.value for queue_item in checkpoints] def _priority(self, checkpoint): - result = flatten_dict(checkpoint.result) + result = flatten_dict(checkpoint.metrics) priority = result[self._checkpoint_score_attr] if self._checkpoint_score_desc: priority = -priority diff --git a/python/ray/tune/examples/tf_mnist_example.py b/python/ray/tune/examples/tf_mnist_example.py index e2bde13e4106..1468a18d6d90 100644 --- a/python/ray/tune/examples/tf_mnist_example.py +++ b/python/ray/tune/examples/tf_mnist_example.py @@ -109,10 +109,10 @@ def step(self): # It is important to return tf.Tensors as numpy objects. return { "epoch": self.iteration, - "loss": self.train_loss.result().numpy(), - "accuracy": self.train_accuracy.result().numpy() * 100, - "test_loss": self.test_loss.result().numpy(), - "mean_accuracy": self.test_accuracy.result().numpy() * 100, + "loss": self.train_loss.metrics().numpy(), + "accuracy": self.train_accuracy.metrics().numpy() * 100, + "test_loss": self.test_loss.metrics().numpy(), + "mean_accuracy": self.test_accuracy.metrics().numpy() * 100, } diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f793714262f2..466bbc606d0d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -660,7 +660,7 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): def on_restore(self): """Handles restoration completion.""" assert self.is_restoring - self.last_result = self.restoring_from.result + self.last_result = self.restoring_from.metrics self.restoring_from = None self.invalidate_json_state() diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index e83ccd82d6fa..eee6f07e98d2 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -370,7 +370,7 @@ def get_channel( # Wait for the SpecificServer to become ready. server.wait_ready() try: - grpc.channel_ready_future(server.channel).result( + grpc.channel_ready_future(server.channel).metrics( timeout=CHECK_CHANNEL_TIMEOUT_S ) return server.channel diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 1f44e0287503..bc2548db62fe 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -216,7 +216,7 @@ def _connect_channel(self, reconnecting=False) -> None: try: # Let gRPC wait for us to see if the channel becomes ready. # If it throws, we couldn't connect. - grpc.channel_ready_future(self.channel).result(timeout=timeout) + grpc.channel_ready_future(self.channel).metrics(timeout=timeout) # The HTTP2 channel is ready. Wrap the channel with the # RayletDriverStub, allowing for unary requests. 
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py index 9c448474436c..469770e02435 100644 --- a/python/ray/util/dask/scheduler.py +++ b/python/ray/util/dask/scheduler.py @@ -442,7 +442,7 @@ def render_progress_bar(tracker, object_refs): from tqdm import tqdm # At this time, every task should be submitted. - total, finished = ray.get(tracker.result.remote()) + total, finished = ray.get(tracker.metrics.remote()) reported_finished_so_far = 0 pb_bar = tqdm(total=total, position=0) pb_bar.set_description("") @@ -450,7 +450,7 @@ def render_progress_bar(tracker, object_refs): ready_refs = [] while finished < total: - submitted, finished = ray.get(tracker.result.remote()) + submitted, finished = ray.get(tracker.metrics.remote()) pb_bar.update(finished - reported_finished_so_far) reported_finished_so_far = finished ready_refs, _ = ray.wait( @@ -462,7 +462,7 @@ def render_progress_bar(tracker, object_refs): time.sleep(0.1) pb_bar.close() - submitted, finished = ray.get(tracker.result.remote()) + submitted, finished = ray.get(tracker.metrics.remote()) if submitted != finished: print("Completed. There was state inconsistency.") from pprint import pprint diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index f97068b58377..2334d69ad7cd 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,3 +1,4 @@ +import copy import gc import heapq import logging @@ -10,6 +11,7 @@ from typing import Optional, Dict, Union, Callable, Tuple, List, Any import ray +from ray.ml import Checkpoint from ray.tune.result import NODE_IP from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI @@ -29,12 +31,21 @@ class TrackedCheckpoint: order to add metadata (e.g. the result, or the node where it has been created) and for bookkeeping purposes. + The data can be an object, a checkpoint directory, or a future to either. Because + we can't know if it's data or a directory from a future, this class expects + a ``storage_mode`` that makes the data type explicit. + + The passed metrics can be used to compare performance of different checkpoints. + The ``checkpoint_id`` is passed as an alternative to be able to order + checkpoints in time. + Args: dir_or_data: Checkpoint directory, checkpoint data, or a future to either. storage_mode: Either MEMORY or PERSISTENT. - checkpoint_id: Checkpoint number. Usually this should be monotonically + checkpoint_id: Checkpoint number. Will be used to determine checkpoint order + if metrics are not available. Usually this should be monotonically increasing for each tracked checkpoint. - result: Observed metrics for this checkpoint. This is used to determine + metrics: Observed metrics for this checkpoint. This is used to determine the value of the ``checkpoint_score_attr``. node_ip: IP of the node where the checkpoint was generated. Defaults to the current node. 
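For illustration (outside the diff): a sketch of constructing a TrackedCheckpoint with the renamed metrics argument introduced here; the data and target path are assumptions made for the example.

    from pathlib import Path

    from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint

    # Dict data tracked as a persistent checkpoint. The metrics dict provides
    # the values later looked up via checkpoint_score_attribute.
    tracked = TrackedCheckpoint(
        dir_or_data={"epoch": 3, "weights": [0.1, 0.2]},
        storage_mode=TrackedCheckpoint.PERSISTENT,
        checkpoint_id=3,
        metrics={"loss": 0.25},
    )

    # With the commit() implementation added in the following hunk, dict data
    # is materialized to the given path; MEMORY checkpoints and non-dict data
    # are left untouched.
    tracked.commit(path=Path("/tmp/my_run/checkpoint_000003"))
    tracked.delete()  # removes the persisted data via the default delete function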
@@ -48,16 +59,15 @@ def __init__( dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], storage_mode: str, checkpoint_id: Optional[int] = None, - result: Optional[Dict] = None, + metrics: Optional[Dict] = None, node_ip: Optional[str] = None, ): self.dir_or_data = dir_or_data self.id = checkpoint_id self.storage_mode = storage_mode - # Todo: What to do if result is a subset of dir_or_data (dict) - self.result = result or {} - self.node_ip = node_ip or self.result.get(NODE_IP, None) + self.metrics = metrics or {} + self.node_ip = node_ip or self.metrics.get(NODE_IP, None) def commit(self, path: Optional[Path] = None) -> None: """Commit checkpoint to disk, if needed. @@ -65,7 +75,20 @@ def commit(self, path: Optional[Path] = None) -> None: Args: path: Path to commit checkpoint to. """ - pass + if self.storage_mode == TrackedCheckpoint.MEMORY: + # Do not persist memory checkpoints + return + + if not isinstance(self.dir_or_data, dict): + # Only persist dictionaries + return + + if not path: + # If no path is given, skip + return + + checkpoint = Checkpoint.from_dict(self.dir_or_data) + self.dir_or_data = checkpoint.to_directory(str(path)) def delete( self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None @@ -84,7 +107,7 @@ def delete( def __repr__(self): if self.storage_mode == TrackedCheckpoint.MEMORY: - return f"" + return f"" return ( f" str: sdk = sdk or get_anyscale_sdk() result = sdk.get_project(project_id) - return result.result.name + return result.metrics.name def get_cluster_name(cluster_id: str, sdk: Optional[AnyscaleSDK] = None) -> str: sdk = sdk or get_anyscale_sdk() result = sdk.get_cluster(cluster_id) - return result.result.name + return result.metrics.name diff --git a/release/ray_release/cluster_manager/full.py b/release/ray_release/cluster_manager/full.py index 32aafdf2de81..cc91380b5e15 100644 --- a/release/ray_release/cluster_manager/full.py +++ b/release/ray_release/cluster_manager/full.py @@ -37,7 +37,7 @@ def start_cluster(self, timeout: float = 600.0): idle_timeout_minutes=self.autosuspend_minutes, ) ) - self.cluster_id = result.result.id + self.cluster_id = result.metrics.id except Exception as e: raise ClusterCreationError(f"Error creating cluster: {e}") from e @@ -50,8 +50,8 @@ def start_cluster(self, timeout: float = 600.0): try: result = self.sdk.start_cluster(self.cluster_id, start_cluster_options={}) - cop_id = result.result.id - completed = result.result.completed + cop_id = result.metrics.id + completed = result.metrics.completed except Exception as e: raise ClusterStartupError( f"Error starting cluster with name " @@ -88,14 +88,14 @@ def start_cluster(self, timeout: float = 600.0): initial_retry_delay_s=2, max_retries=3, ) - completed = result.result.completed + completed = result.metrics.completed result = self.sdk.get_cluster(self.cluster_id) - if result.result.state != "Running": + if result.metrics.state != "Running": raise ClusterStartupFailed( f"Cluster did not come up - most likely the nodes are currently " f"not available. Please check the cluster startup logs: " - f"{cluster_url} (cluster state: {result.result.state})" + f"{cluster_url} (cluster state: {result.metrics.state})" ) def terminate_cluster(self, wait: bool = False): @@ -109,8 +109,8 @@ def terminate_cluster(self, wait: bool = False): return # Only do this when waiting - cop_id = result.result.id - completed = result.result.completed + cop_id = result.metrics.id + completed = result.metrics.completed while not completed: # Sleep 1 sec before next check. 
time.sleep(1) @@ -118,10 +118,10 @@ def terminate_cluster(self, wait: bool = False): cluster_operation_response = self.sdk.get_cluster_operation( cop_id, _request_timeout=30 ) - cluster_operation = cluster_operation_response.result + cluster_operation = cluster_operation_response.metrics completed = cluster_operation.completed result = self.sdk.get_cluster(self.cluster_id) - while result.result.state != "Terminated": + while result.metrics.state != "Terminated": time.sleep(1) result = self.sdk.get_cluster(self.cluster_id) diff --git a/release/ray_release/cluster_manager/minimal.py b/release/ray_release/cluster_manager/minimal.py index 37feee9d5c21..1be9c1ed28de 100644 --- a/release/ray_release/cluster_manager/minimal.py +++ b/release/ray_release/cluster_manager/minimal.py @@ -64,7 +64,7 @@ def create_cluster_env(self, _repeat: bool = True): config_json=self.cluster_env, ) ) - self.cluster_env_id = result.result.id + self.cluster_env_id = result.metrics.id except Exception as e: if _repeat: logger.warning( @@ -120,7 +120,7 @@ def build_cluster_env(self, timeout: float = 600.0): cluster_environment_id=self.cluster_env_id, config_json=config_json ) ) - build_id = result.result.id + build_id = result.metrics.id logger.info( f"Link to cluster env build: " @@ -150,7 +150,7 @@ def build_cluster_env(self, timeout: float = 600.0): next_report = next_report + REPORT_S result = self.sdk.get_build(build_id) - build = result.result + build = result.metrics if build.status == "failed": raise ClusterEnvBuildError( @@ -185,7 +185,7 @@ def fetch_build_info(self): assert self.cluster_env_build_id result = self.sdk.get_cluster_environment_build(self.cluster_env_build_id) - self.cluster_env = result.result.config_json + self.cluster_env = result.metrics.config_json def create_cluster_compute(self, _repeat: bool = True): assert self.cluster_compute_id is None @@ -236,7 +236,7 @@ def create_cluster_compute(self, _repeat: bool = True): config=self.cluster_compute, ) ) - self.cluster_compute_id = result.result.id + self.cluster_compute_id = result.metrics.id except Exception as e: if _repeat: logger.warning( diff --git a/release/ray_release/command_runner/sdk_runner.py b/release/ray_release/command_runner/sdk_runner.py index f582a601f06c..7d490400835c 100644 --- a/release/ray_release/command_runner/sdk_runner.py +++ b/release/ray_release/command_runner/sdk_runner.py @@ -94,10 +94,10 @@ def run_command( dict(session_id=self.cluster_manager.cluster_id, shell_command=full_command) ) - scd_id = result.result.id + scd_id = result.metrics.id self.last_command_scd_id = scd_id - completed = result.result.finished_at is not None + completed = result.metrics.finished_at is not None start_time = time.monotonic() timeout_at = start_time + timeout @@ -126,9 +126,9 @@ def run_command( initial_retry_delay_s=10, max_retries=3, ) - completed = result.result.finished_at + completed = result.metrics.finished_at - status_code = result.result.status_code + status_code = result.metrics.status_code time_taken = time.monotonic() - start_time if status_code != 0: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index b02184214bab..9090c6e52fdb 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -40,6 +40,9 @@ ) from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.buffers.multi_agent_replay_buffer import ( + MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer, +) from ray.rllib.utils.replay_buffers import 
MultiAgentReplayBuffer from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ( @@ -664,7 +667,7 @@ def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done): else: step_results.update(self.evaluate()) # Collect the training results from the future. - step_results.update(train_future.result()) + step_results.update(train_future.metrics()) # Sequential: train (already done above), then eval. else: step_results.update(self.evaluate()) @@ -2124,7 +2127,7 @@ def __setstate__(self, state: dict): @DeveloperAPI def _create_local_replay_buffer_if_necessary( self, config: PartialTrainerConfigDict - ) -> Optional[MultiAgentReplayBuffer]: + ) -> Optional[Union[MultiAgentReplayBuffer, Legacy_MultiAgentReplayBuffer]]: """Create a MultiAgentReplayBuffer instance if necessary. Args: @@ -2135,7 +2138,7 @@ def _create_local_replay_buffer_if_necessary( None, if local replay buffer is not needed. """ if not config.get("replay_buffer_config") or config["replay_buffer_config"].get( - "no_local_replay_buffer" or config.get("no_local_replay_buffer") + "no_local_replay_buffer" or config.get("no_local_replay_buffer"), False ): return From cb5eed31f40b7a04d436ad82053694dfc400b6d4 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:43:37 +0100 Subject: [PATCH 35/81] [tune/train] Consolidate checkpoint manager 2: Ray Train --- python/ray/train/checkpoint.py | 311 ++++++++++++++------------------- python/ray/train/trainer.py | 11 +- 2 files changed, 136 insertions(+), 186 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index dd03ed3197eb..ffbddb5bd4a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -1,27 +1,28 @@ -import heapq import logging -import numbers -import os -from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import TIMESTAMP, TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID +from ray.train.constants import ( + TIMESTAMP, + TRAIN_CHECKPOINT_SUBDIR, + TUNE_CHECKPOINT_FILE_NAME, + TUNE_CHECKPOINT_ID, + TUNE_INSTALLED, +) from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.util import PublicAPI -from ray.util.ml_utils.util import is_nan +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointManager as CommonCheckpointManager, + TrackedCheckpoint, + CheckpointStrategy, +) if TUNE_INSTALLED: from ray import tune else: tune = None -MAX = "max" -MIN = "min" - logger = logging.getLogger(__name__) @@ -34,64 +35,58 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: return cloudpickle.load(f) -@PublicAPI(stability="beta") -@dataclass -class CheckpointStrategy: - """Configurable parameters for defining the Train checkpointing strategy. - - Default behavior is to persist all checkpoints to disk. If - ``num_to_keep`` is set, the default retention policy is to keep the - checkpoints with maximum timestamp, i.e. the most recent checkpoints. - - Args: - num_to_keep (Optional[int]): The number of checkpoints to keep - on disk for this run. If a checkpoint is persisted to disk after - there are already this many checkpoints, then an existing - checkpoint will be deleted. If this is ``None`` then checkpoints - will not be deleted. If this is ``0`` then no checkpoints will be - persisted to disk. 
- checkpoint_score_attribute (str): The attribute that will be used to - score checkpoints to determine which checkpoints should be kept - on disk when there are greater than ``num_to_keep`` checkpoints. - This attribute must be a key from the checkpoint - dictionary which has a numerical value. - checkpoint_score_order (str). Either "max" or "min". - If "max", then checkpoints with highest values of - ``checkpoint_score_attribute`` will be kept. - If "min", then checkpoints with lowest values of - ``checkpoint_score_attribute`` will be kept. +class _NotYetPersistedCheckpoint(TrackedCheckpoint): + """Tracked checkpoint that is not yet persisted to disk. + + This checkpoint class supports lazy writing. The checkpoint manager will + only call ``commit()`` if the checkpoint should be kept on disk. This class + will only then write checkpoint data to disk. """ - num_to_keep: Optional[int] = None - checkpoint_score_attribute: str = TIMESTAMP - checkpoint_score_order: str = MAX + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def __post_init__(self): - if self.num_to_keep is not None and self.num_to_keep < 0: - raise ValueError( - f"Received invalidate num_to_keep: " - f"{self.num_to_keep}. " - f"Must be None or non-negative integer." - ) - if self.checkpoint_score_order not in (MAX, MIN): - raise ValueError( - f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' - ) + self._data_to_commit = self.dir_or_data + self.dir_or_data = None + @property + def committed(self) -> bool: + return not self._data_to_commit -class PersistedCheckpoint: - def __init__(self, path, priority): - self.path = path - self.priority = priority + def commit(self, path: Optional[Path] = None): + if self.committed: + return - def __lt__(self, other): - return self.priority < other.priority + assert path - def __repr__(self): - return f"PersistedCheckpoint({repr(self.path)})" + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(self._data_to_commit, f) + logger.debug(f"Checkpoint successfully written to: {path}") + self.dir_or_data = path + self._data_to_commit = None -class CheckpointManager: + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): + if not self.committed: + return + return super().delete(delete_fn=delete_fn) + + @classmethod + def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): + new_checkpoint = cls( + dir_or_data=checkpoint.dir_or_data, + storage_mode=TrackedCheckpoint.PERSISTENT, + checkpoint_id=checkpoint.id, + result=checkpoint.result, + node_ip=checkpoint.node_ip, + ) + return new_checkpoint + + +class CheckpointManager(CommonCheckpointManager): """Manages checkpoint processing, writing, and loading. @@ -119,55 +114,16 @@ class CheckpointManager: checkpoint may not be saved to disk. """ - def on_init(self, **kwargs): - """Checkpoint code executed during BackendExecutor init.""" - self.latest_checkpoint = None - - # Incremental unique checkpoint ID of this run. - self._latest_checkpoint_id = 0 - - # Used for keeping top K checkpoints. - self._top_persisted_checkpoints = [] - - # Best checkpoint altogether. - # Used for exposing best_checkpoint_path. 
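Illustrative sketch (not part of the patch) of the lazy-write flow that _NotYetPersistedCheckpoint enables at this point in the series; wrapping is normally done internally by the manager, and the target file path is an assumption for this example.

    from pathlib import Path

    from ray.train.checkpoint import _NotYetPersistedCheckpoint
    from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint

    # Construct a lazily persisted checkpoint; no data is written yet.
    lazy = _NotYetPersistedCheckpoint(
        dir_or_data={"epoch": 1},
        storage_mode=TrackedCheckpoint.PERSISTENT,
        checkpoint_id=1,
    )
    assert not lazy.committed

    # Only if the manager decides to keep this checkpoint is commit() called,
    # which pickles the held data to the given file path.
    lazy.commit(path=Path("/tmp/run/checkpoint_000001/checkpoint_file"))
    assert lazy.committed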
- self._best_persisted_checkpoint = None - - def on_start_training( - self, - checkpoint_strategy: Optional[CheckpointStrategy], - run_dir: Path, - latest_checkpoint_id: Optional[int] = None, - ): - """Checkpoint code executed during BackendExecutor start_training.""" - # Restart checkpointing. - self._latest_checkpoint_id = latest_checkpoint_id if latest_checkpoint_id else 0 - self._checkpoint_strategy = ( - CheckpointStrategy() if checkpoint_strategy is None else checkpoint_strategy - ) + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir - def _process_checkpoint( - self, - checkpoint_results: List[TrainingResult], - decode_checkpoint_fn: Callable, - ) -> None: - """Perform all processing for a checkpoint.""" - - # Get checkpoint from first worker. - checkpoint = checkpoint_results[0].data + super().__init__(checkpoint_strategy=checkpoint_strategy) - # Decode checkpoint. - checkpoint = decode_checkpoint_fn(checkpoint) + self._validate_checkpoint_strategy() - # Store checkpoint in memory. - self.latest_checkpoint = checkpoint - - # Write checkpoint to disk. - self.write_checkpoint(checkpoint) - - # Increment checkpoint id. - self._latest_checkpoint_id += 1 + def _validate_checkpoint_strategy(self): + if self._checkpoint_strategy.checkpoint_score_attribute is None: + self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP def _load_checkpoint( self, checkpoint_to_load: Optional[Union[Dict, str, Path]] @@ -181,94 +137,77 @@ def _load_checkpoint( # Load checkpoint from path. return load_checkpoint_from_path(checkpoint_to_load) - def write_checkpoint(self, checkpoint: Dict): - """Writes checkpoint to disk.""" - num_to_keep = self._checkpoint_strategy.num_to_keep + def _process_checkpoint( + self, + checkpoint_results: List[TrainingResult], + decode_checkpoint_fn: Callable, + ) -> None: + """Ray Train entrypoint. Perform all processing for a checkpoint.""" + # Get checkpoint from first worker. + checkpoint_data = checkpoint_results[0].data - if num_to_keep == 0: - # Checkpoints should not be persisted to disk. - return + # Decode checkpoint. + checkpoint_data = decode_checkpoint_fn(checkpoint_data) - checkpoint_score_attribute = ( - self._checkpoint_strategy.checkpoint_score_attribute - ) - checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_attribute not in checkpoint: + score_attr = self._checkpoint_strategy.checkpoint_score_attribute + if ( + self._checkpoint_strategy.num_to_keep != 0 + and score_attr not in checkpoint_data + ): raise ValueError( f"Unable to persist checkpoint for " f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute}. " + f"{score_attr}. " f"Include this attribute in the call to " f"train.save_checkpoint." ) - checkpoint_score = checkpoint[checkpoint_score_attribute] - if not isinstance(checkpoint_score, numbers.Number): - raise ValueError( - f"Unable to persist checkpoint for " - f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute} with value " - f"{checkpoint_score}. " - f"This attribute must be numerical." 
+ tracked_checkpoint = TrackedCheckpoint( + dir_or_data=checkpoint_data, + checkpoint_id=self._latest_checkpoint_id, + storage_mode=TrackedCheckpoint.MEMORY, + result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + ) + self.register_checkpoint(checkpoint=tracked_checkpoint) + + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + # Always update the latest memory checkpoint + self._replace_latest_memory_checkpoint(checkpoint) + + # Only process further if we consider keeping this checkpoint on disk + if self._checkpoint_strategy.num_to_keep != 0: + not_yet_persisted_checkpoint = ( + _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) ) + self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - def priority(checkpoint_score_order, checkpoint_score): - # Treat NaN as worst - # The tuple structure is (not is_nan(), metric), which makes - # the nan values to be always considered as the worst - # metrics by the heap - if checkpoint_score_order != MAX: - checkpoint_score = -checkpoint_score - return (not is_nan(checkpoint_score), checkpoint_score) + self._latest_checkpoint_id += 1 + + def _get_next_checkpoint_path(self) -> Optional[Path]: + """Path to the next checkpoint to persist.""" + checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + return self.latest_checkpoint_dir.joinpath(checkpoint_file) - checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) + def on_start_training( + self, + checkpoint_strategy: Optional[CheckpointStrategy], + run_dir: str, + latest_checkpoint_id: Optional[int] = 0, + ): + checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + self._checkpoint_strategy = checkpoint_strategy - persisted_checkpoint = PersistedCheckpoint( - self.next_checkpoint_path, checkpoint_priority - ) + self._validate_checkpoint_strategy() - def write_to_disk(path: Path): - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(checkpoint, f) - logger.debug(f"Checkpoint successfully written to: " f"{path}") - - def remove_from_disk(path: Path): - os.remove(path) - - if num_to_keep is None: - # Keep all checkpoints. - write_to_disk(self.next_checkpoint_path) - elif len(self._top_persisted_checkpoints) < num_to_keep: - # Keep first num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - heapq.heappush(self._top_persisted_checkpoints, persisted_checkpoint) - elif ( - persisted_checkpoint.priority > self._top_persisted_checkpoints[0].priority - ): - # Keep top num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - worst_checkpoint = heapq.heappushpop( - self._top_persisted_checkpoints, persisted_checkpoint - ) - worst_checkpoint_path = worst_checkpoint.path - remove_from_disk(worst_checkpoint_path) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint_path}.") - else: - # If the latest checkpoint has the same or lower priority, skip it. - logger.debug( - f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." - ) + self.run_dir = run_dir + self._latest_checkpoint_id = latest_checkpoint_id or 0 - # Update single best checkpoint. - if ( - self._best_persisted_checkpoint is None - or persisted_checkpoint.priority > self._best_persisted_checkpoint.priority - ): - # If the latest checkpoint has the same or lower priority, skip it. 
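For illustration (outside the diff): the removed heap code above and the shared _get_checkpoint_score both rely on a tuple whose first element demotes NaN metrics. A self-contained sketch of that ordering, using math.isnan in place of the is_nan helper.

    import math

    def score_tuple(value: float, checkpoint_id: int):
        # Mirrors the (not is_nan(value), value, id) structure used for "max"
        # ordering: the leading False for NaN makes such checkpoints sort worst.
        return (not math.isnan(value), value, checkpoint_id)

    assert score_tuple(0.1, 0) > score_tuple(float("nan"), 1)
    assert max(score_tuple(float("nan"), 0), score_tuple(-5.0, 1)) == score_tuple(-5.0, 1)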
- self._best_persisted_checkpoint = persisted_checkpoint + # Train-specific attributes + @property + def latest_checkpoint(self): + if not self._latest_memory_checkpoint: + return None + return self._latest_memory_checkpoint.dir_or_data @property def latest_checkpoint_dir(self) -> Optional[Path]: @@ -294,7 +233,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return self._best_persisted_checkpoint.path + return Path(self._best_persisted_checkpoint.dir_or_data) else: return None @@ -328,16 +267,22 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) + def _decide_what_to_do_with_checkpoint( + self, checkpoint: _NotYetPersistedCheckpoint + ): + assert isinstance(checkpoint, _NotYetPersistedCheckpoint) + assert not checkpoint.committed + + self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) # Use a standard file name so that we know which file to load # the checkpoint from. file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - with file_path.open("wb") as f: - cloudpickle.dump(checkpoint, f) + checkpoint.commit(file_path) + + return super()._decide_what_to_do_with_checkpoint(checkpoint) def construct_checkpoint_file_name(checkpoint_id: int) -> str: diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 798dd99454a0..efb39e9c3535 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,11 +224,16 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) + # Todo (krfricke): Initialize checkpoint manager here with final values + # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager() + self.checkpoint_manager = TuneCheckpointManager( + checkpoint_strategy=None, run_dir=None + ) else: - self.checkpoint_manager = CheckpointManager() - self.checkpoint_manager.on_init() + self.checkpoint_manager = CheckpointManager( + checkpoint_strategy=None, run_dir=None + ) def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From 0e6c420e47a4cc6b3d99df510002dc182da4e251 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:40:35 +0100 Subject: [PATCH 36/81] Adjust to changes in base PR --- python/ray/train/checkpoint.py | 81 +++----------------------- python/ray/train/tests/test_trainer.py | 1 + 2 files changed, 8 insertions(+), 74 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index ffbddb5bd4a8..1bc2349ff5b7 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional, Dict, Union, Callable -from ray import cloudpickle +from ray.ml import Checkpoint from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, @@ -31,59 +31,8 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: checkpoint_path = Path(checkpoint_to_load).expanduser() if not checkpoint_path.exists(): raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") - with 
checkpoint_path.open("rb") as f: - return cloudpickle.load(f) - - -class _NotYetPersistedCheckpoint(TrackedCheckpoint): - """Tracked checkpoint that is not yet persisted to disk. - - This checkpoint class supports lazy writing. The checkpoint manager will - only call ``commit()`` if the checkpoint should be kept on disk. This class - will only then write checkpoint data to disk. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._data_to_commit = self.dir_or_data - self.dir_or_data = None - - @property - def committed(self) -> bool: - return not self._data_to_commit - - def commit(self, path: Optional[Path] = None): - if self.committed: - return - - assert path - - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(self._data_to_commit, f) - logger.debug(f"Checkpoint successfully written to: {path}") - - self.dir_or_data = path - self._data_to_commit = None - - def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): - if not self.committed: - return - return super().delete(delete_fn=delete_fn) - - @classmethod - def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): - new_checkpoint = cls( - dir_or_data=checkpoint.dir_or_data, - storage_mode=TrackedCheckpoint.PERSISTENT, - checkpoint_id=checkpoint.id, - result=checkpoint.result, - node_ip=checkpoint.node_ip, - ) - return new_checkpoint + checkpoint = Checkpoint.from_directory(str(checkpoint_path)) + return checkpoint.to_dict() class CheckpointManager(CommonCheckpointManager): @@ -114,6 +63,8 @@ class CheckpointManager(CommonCheckpointManager): checkpoint may not be saved to disk. """ + _persist_memory_checkpoints = True + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir @@ -166,23 +117,10 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=TrackedCheckpoint.MEMORY, - result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) - def register_checkpoint(self, checkpoint: TrackedCheckpoint): - # Always update the latest memory checkpoint - self._replace_latest_memory_checkpoint(checkpoint) - - # Only process further if we consider keeping this checkpoint on disk - if self._checkpoint_strategy.num_to_keep != 0: - not_yet_persisted_checkpoint = ( - _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) - ) - self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - - self._latest_checkpoint_id += 1 - def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) @@ -267,12 +205,7 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint( - self, checkpoint: _NotYetPersistedCheckpoint - ): - assert isinstance(checkpoint, _NotYetPersistedCheckpoint) - assert not checkpoint.committed - + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. 
with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index dc72e21a939f..cc017de709b9 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -15,6 +15,7 @@ from ray.train.constants import TRAIN_ENABLE_WORKER_SPREAD_ENV from ray.train.torch import TorchConfig from ray.train.tensorflow import TensorflowConfig + from ray.train.horovod import HorovodConfig from ray.train.callbacks.callback import TrainingCallback from ray.train.worker_group import WorkerGroup From a8928aa43983f698fe613e760d07628e8c19e9f7 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:50:14 +0100 Subject: [PATCH 37/81] Adjust to changes in base PR --- python/ray/tune/checkpoint_manager.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 3d8b27f91de2..1c1ff068d9c3 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -10,6 +10,8 @@ CheckpointManager as CommonCheckpointManager, TrackedCheckpoint, ) +from ray.util.ml_utils.dict import flatten_dict +from ray.util.ml_utils.util import is_nan logger = logging.getLogger(__name__) @@ -115,8 +117,8 @@ def newest_memory_checkpoint(self): def best_checkpoints(self): """Returns best PERSISTENT checkpoints, sorted by score.""" - checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority) - return [queue_item.value for queue_item in checkpoints] + checkpoints = sorted(self._top_persisted_checkpoints, key=lambda c: c.priority) + return [wrapped.tracked_checkpoint for wrapped in checkpoints] def _priority(self, checkpoint): result = flatten_dict(checkpoint.metrics) @@ -132,8 +134,8 @@ def _priority(self, checkpoint): def __getstate__(self): state = self.__dict__.copy() # Avoid serializing the memory checkpoint. - state["_newest_memory_checkpoint"] = _TuneCheckpoint( - _TuneCheckpoint.MEMORY, None + state["_newest_memory_checkpoint"] = TrackedCheckpoint( + TrackedCheckpoint.MEMORY, None ) # Avoid serializing lambda since it may capture cyclical dependencies. state.pop("delete") From 2bd5f750e1684f8233c5bf2a04ab5458cdd3f6f9 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:41:35 +0100 Subject: [PATCH 38/81] [tune/train] Consolidate checkpoint manager 1: Common checkpoint manager class --- .../ray/util/ml_utils/checkpoint_manager.py | 378 ++++++++++++++++++ 1 file changed, 378 insertions(+) create mode 100644 python/ray/util/ml_utils/checkpoint_manager.py diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py new file mode 100644 index 000000000000..f97068b58377 --- /dev/null +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -0,0 +1,378 @@ +import gc +import heapq +import logging +import numbers +import os +import shutil + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Dict, Union, Callable, Tuple, List, Any + +import ray +from ray.tune.result import NODE_IP +from ray.util import PublicAPI +from ray.util.annotations import DeveloperAPI +from ray.util.ml_utils.util import is_nan + +MAX = "max" +MIN = "min" + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class TrackedCheckpoint: + """Checkpoint tracked by a checkpoint manager. 
+ + This class is used to track checkpoints generated by trainables and trainers in + order to add metadata (e.g. the result, or the node where it has been created) + and for bookkeeping purposes. + + Args: + dir_or_data: Checkpoint directory, checkpoint data, or a future to either. + storage_mode: Either MEMORY or PERSISTENT. + checkpoint_id: Checkpoint number. Usually this should be monotonically + increasing for each tracked checkpoint. + result: Observed metrics for this checkpoint. This is used to determine + the value of the ``checkpoint_score_attr``. + node_ip: IP of the node where the checkpoint was generated. Defaults + to the current node. + """ + + MEMORY = "memory" + PERSISTENT = "persistent" + + def __init__( + self, + dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], + storage_mode: str, + checkpoint_id: Optional[int] = None, + result: Optional[Dict] = None, + node_ip: Optional[str] = None, + ): + self.dir_or_data = dir_or_data + self.id = checkpoint_id + self.storage_mode = storage_mode + + # Todo: What to do if result is a subset of dir_or_data (dict) + self.result = result or {} + self.node_ip = node_ip or self.result.get(NODE_IP, None) + + def commit(self, path: Optional[Path] = None) -> None: + """Commit checkpoint to disk, if needed. + + Args: + path: Path to commit checkpoint to. + """ + pass + + def delete( + self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None + ) -> None: + """Delete checkpoint from disk, if needed. + + Args: + delete_fn: Function to be called with the tracked checkpoint as an + argument. Defaults to removing the local directory/file. + """ + delete_fn = delete_fn or _default_delete_fn + try: + delete_fn(self) + except Exception as e: + logger.warning(f"Checkpoint deletion failed: {e}") + + def __repr__(self): + if self.storage_mode == TrackedCheckpoint.MEMORY: + return f"" + + return ( + f"" + ) + + +def _default_delete_fn(checkpoint: TrackedCheckpoint): + if checkpoint.storage_mode != TrackedCheckpoint.PERSISTENT: + return + + if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): + if os.path.isfile(checkpoint.dir_or_data): + os.remove(checkpoint.dir_or_data) + return + elif os.path.isdir(checkpoint.dir_or_data): + shutil.rmtree(checkpoint.dir_or_data) + return + raise RuntimeError( + f"Could not delete checkpoint {checkpoint} from disk as it is " + f"neither file not directory. Path: {checkpoint.dir_or_data}." + ) + + +class _HeapCheckpointWrapper: + def __init__(self, priority: Any, tracked_checkpoint: TrackedCheckpoint): + self.priority = priority + self.tracked_checkpoint = tracked_checkpoint + + def __lt__(self, other): + return self.priority < other.priority + + def __repr__(self): + return f"_HeapCheckpoint({repr(self.tracked_checkpoint)})" + + +@PublicAPI(stability="beta") +@dataclass +class CheckpointStrategy: + """Configurable parameters for defining the checkpointing strategy. + + Default behavior is to persist all checkpoints to disk. If + ``num_to_keep`` is set, the default retention policy is to keep the + checkpoints with maximum timestamp, i.e. the most recent checkpoints. + + Args: + num_to_keep (Optional[int]): The number of checkpoints to keep + on disk for this run. If a checkpoint is persisted to disk after + there are already this many checkpoints, then an existing + checkpoint will be deleted. If this is ``None`` then checkpoints + will not be deleted. If this is ``0`` then no checkpoints will be + persisted to disk. 
+ checkpoint_score_attribute (str): The attribute that will be used to + score checkpoints to determine which checkpoints should be kept + on disk when there are greater than ``num_to_keep`` checkpoints. + This attribute must be a key from the checkpoint + dictionary which has a numerical value. Per default, the last + checkpoints will be kept. + checkpoint_score_order (str). Either "max" or "min". + If "max", then checkpoints with highest values of + ``checkpoint_score_attribute`` will be kept. + If "min", then checkpoints with lowest values of + ``checkpoint_score_attribute`` will be kept. + """ + + num_to_keep: Optional[int] = None + checkpoint_score_attribute: Optional[str] = None + checkpoint_score_order: str = MAX + + def __post_init__(self): + if self.num_to_keep is not None and self.num_to_keep < 0: + raise ValueError( + f"Received invalid num_to_keep: " + f"{self.num_to_keep}. " + f"Must be None or non-negative integer." + ) + if self.checkpoint_score_order not in (MAX, MIN): + raise ValueError( + f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' + ) + + +class CheckpointManager: + """Common checkpoint management and bookkeeping class for Ray Train and Tune. + + This class acts as the common core for checkpoint bookkeeping in Ray ML libraries. + On a high level, this manager keeps a reference to all stored checkpoints + (both in-memory and on-disk checkpoints). For on-disk checkpoints, it + keeps a configured number of checkpoints according to specified metrics. + + The manager supports lazy data writing by utilizing the + ``TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint + should be persisted to disk. + """ + + def __init__( + self, + checkpoint_strategy: CheckpointStrategy, + latest_checkpoint_id: int = 0, + delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, + ): + self._checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + + # Incremental unique checkpoint ID of this run. + self._latest_checkpoint_id = latest_checkpoint_id + + # Used for keeping top K checkpoints. + self._top_persisted_checkpoints: List[_HeapCheckpointWrapper] = [] + + # Best checkpoint altogether. + # Used for exposing best_checkpoint_path. + self._best_persisted_checkpoint: Optional[TrackedCheckpoint] = None + self._latest_persisted_checkpoint: Optional[TrackedCheckpoint] = None + self._latest_memory_checkpoint: Optional[TrackedCheckpoint] = None + + # Checkpoints that are not immediately removed + self._checkpoints_to_clean_up = set() + self._delete_fn = delete_fn + + def set_delete_fn(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]]): + """Update the function called to delete persisted checkpoints. + + Args: + delete_fn: Function that takes a tracked checkpoint as an argument and + deletes it from disk. + """ + self._delete_fn = delete_fn + + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + """Register new checkpoint and add to bookkeeping. + + This method will register a new checkpoint and add it to the internal + bookkeeping logic. This means the checkpoint manager will decide if + this checkpoint should be kept, and if older or worse performing + checkpoints should be deleted. + + Subclasses have to implement this method. + + Args: + checkpoint: Tracked checkpoint object to add to bookkeeping. 
+ """ + raise NotImplementedError + + def _replace_latest_memory_checkpoint(self, memory_checkpoint: TrackedCheckpoint): + assert memory_checkpoint.storage_mode == TrackedCheckpoint.MEMORY + self._latest_memory_checkpoint = memory_checkpoint + # Avoid memory leaks on k8s pods + gc.collect() + + def _replace_latest_persisted_checkpoint( + self, persisted_checkpoint: TrackedCheckpoint + ): + second_to_latest_persisted_checkpoint = self._latest_persisted_checkpoint + self._latest_persisted_checkpoint = persisted_checkpoint + + if self._checkpoint_strategy.num_to_keep == 0: + self._maybe_delete_persisted_checkpoint( + second_to_latest_persisted_checkpoint + ) + + def _maybe_replace_best_persisted_checkpoint( + self, persisted_checkpoint: TrackedCheckpoint + ): + if self._best_persisted_checkpoint is None: + self._best_persisted_checkpoint = persisted_checkpoint + else: + old_score = self._get_checkpoint_score(self._best_persisted_checkpoint) + candidate_score = self._get_checkpoint_score(persisted_checkpoint) + if candidate_score >= old_score: + self._best_persisted_checkpoint = persisted_checkpoint + + def _get_checkpoint_score( + self, checkpoint: TrackedCheckpoint + ) -> Tuple[bool, numbers.Number, int]: + checkpoint_score_attribute = ( + self._checkpoint_strategy.checkpoint_score_attribute + ) + if checkpoint_score_attribute not in checkpoint.result: + logger.error( + f"Result dict has no key: {checkpoint_score_attribute}. " + f"checkpoint_score_attr must be set to a key in the " + f"result dict. Valid keys are: {list(checkpoint.result.keys())}" + ) + checkpoint_result = float("-inf") + else: + checkpoint_result = checkpoint.result[checkpoint_score_attribute] + + checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order + if checkpoint_score_order == MAX: + order_factor = 1.0 + else: + order_factor = -1.0 + + checkpoint_score = order_factor * checkpoint_result + + if not isinstance(checkpoint_score, numbers.Number): + raise ValueError( + f"Unable to persist checkpoint for " + f"checkpoint_score_attribute: " + f"{checkpoint_score_attribute} with value " + f"{checkpoint_score}. " + f"This attribute must be numerical." 
+ ) + + return ( + not is_nan(checkpoint_score), + checkpoint_score if not is_nan(checkpoint_score) else 0, + checkpoint.id, + ) + + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): + checkpoint_score = self._get_checkpoint_score(checkpoint) + wrapped_checkpoint = _HeapCheckpointWrapper( + priority=checkpoint_score, tracked_checkpoint=checkpoint + ) + + if self._checkpoint_strategy.num_to_keep is None: + # Keep all checkpoints + checkpoint.commit(path=self._get_next_checkpoint_path()) + self._replace_latest_persisted_checkpoint(checkpoint) + self._top_persisted_checkpoints.append(wrapped_checkpoint) + elif ( + len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep + ): + # Heap is not full yet, so keep this checkpoint + checkpoint.commit(path=self._get_next_checkpoint_path()) + heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) + self._replace_latest_persisted_checkpoint(checkpoint) + elif wrapped_checkpoint.priority >= self._top_persisted_checkpoints[0].priority: + # Priority is higher than current worst checkpoint, so replace worst + checkpoint.commit(path=self._get_next_checkpoint_path()) + worst_checkpoint = heapq.heappushpop( + self._top_persisted_checkpoints, wrapped_checkpoint + ).tracked_checkpoint + + # Only remove if checkpoint data is different + if worst_checkpoint.dir_or_data != checkpoint.dir_or_data: + self._maybe_delete_persisted_checkpoint(worst_checkpoint) + logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") + + self._replace_latest_persisted_checkpoint(checkpoint) + else: + # If the latest checkpoint has the same or lower priority, skip it. + self._skip_persisted_checkpoint(checkpoint) + + self._maybe_replace_best_persisted_checkpoint(persisted_checkpoint=checkpoint) + self._cleanup_checkpoints() + + def _maybe_delete_persisted_checkpoint( + self, persisted_checkpoint: TrackedCheckpoint + ): + if persisted_checkpoint == self._latest_persisted_checkpoint: + self._checkpoints_to_clean_up.add(persisted_checkpoint) + else: + self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) + + def _delete_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): + persisted_checkpoint.delete(delete_fn=self._delete_fn) + self._checkpoints_to_clean_up.discard(persisted_checkpoint) + + def _cleanup_checkpoints(self): + for checkpoint in list(self._checkpoints_to_clean_up): + self._maybe_delete_persisted_checkpoint(persisted_checkpoint=checkpoint) + + def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): + logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") + self._checkpoints_to_clean_up.add(persisted_checkpoint) + + def _get_next_checkpoint_path(self) -> Optional[Path]: + return None + + def __del__(self): + self._cleanup_checkpoints() + + def __getstate__(self): + state = self.__dict__.copy() + + # Do not serialize the delete fn + state.pop("_delete_fn", None) + + # Avoid serializing the memory checkpoint. 
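
``_decide_what_to_do_with_checkpoint`` above implements top-k retention with a min-heap whose root is always the worst checkpoint still kept: a new checkpoint is committed if the heap is not full or if it beats that root (in which case the evicted entry is deleted); otherwise it is skipped. A toy sketch of that bookkeeping, with plain ``(score, id)`` tuples standing in for ``_HeapCheckpointWrapper`` and ``print`` standing in for commit/delete (the keep-everything branch for ``num_to_keep is None`` is omitted):

    import heapq

    num_to_keep = 2
    top = []  # min-heap; top[0] is always the worst retained checkpoint

    for checkpoint_id, score in enumerate([0.1, 0.5, 0.3, 0.9]):
        item = (score, checkpoint_id)  # stand-in for _HeapCheckpointWrapper
        if len(top) < num_to_keep:
            heapq.heappush(top, item)               # heap not full: keep it
            print("kept", item)
        elif item >= top[0]:
            evicted = heapq.heappushpop(top, item)  # replace the current worst
            print("kept", item, "deleted", evicted)
        else:
            print("skipped", item)                  # worse than everything kept

    print(sorted(top))  # -> [(0.5, 1), (0.9, 3)]
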
+ state["_newest_memory_checkpoint"] = TrackedCheckpoint( + dir_or_data=None, + checkpoint_id=0, + storage_mode=TrackedCheckpoint.MEMORY, + ) + return state + + def __setstate__(self, state): + state["_delete_fn"] = None + self.__dict__.update(state) From eb3a658019fc217e3fad417f067012af06deb1b1 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:34:38 +0100 Subject: [PATCH 39/81] Add general entrypoint --- dashboard/modules/job/job_manager.py | 2 +- dashboard/optional_utils.py | 2 +- python/ray/_private/gcs_pubsub.py | 6 +- python/ray/tests/test_async.py | 2 +- python/ray/tests/test_client_proxy.py | 2 +- .../ray/tune/analysis/experiment_analysis.py | 2 +- python/ray/tune/checkpoint_manager.py | 2 +- python/ray/tune/examples/tf_mnist_example.py | 8 +- python/ray/tune/trial.py | 2 +- python/ray/util/client/server/proxier.py | 2 +- python/ray/util/client/worker.py | 2 +- python/ray/util/dask/scheduler.py | 6 +- .../ray/util/ml_utils/checkpoint_manager.py | 80 +++++++++++++++---- release/ray_release/anyscale_util.py | 4 +- release/ray_release/cluster_manager/full.py | 20 ++--- .../ray_release/cluster_manager/minimal.py | 10 +-- .../ray_release/command_runner/sdk_runner.py | 8 +- rllib/agents/trainer.py | 9 ++- 18 files changed, 112 insertions(+), 57 deletions(-) diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index a38a59e6e220..758a1ca57c92 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -260,7 +260,7 @@ async def run( # at the same time assert len(finished) == 1, "Should have only one coroutine done" [child_process_task] = finished - return_code = child_process_task.result() + return_code = child_process_task.metrics() if return_code == 0: self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED) else: diff --git a/dashboard/optional_utils.py b/dashboard/optional_utils.py index 284cb8ef9c19..59908d9be11e 100644 --- a/dashboard/optional_utils.py +++ b/dashboard/optional_utils.py @@ -208,7 +208,7 @@ async def _cache_handler(*args) -> aiohttp.web.Response: def _update_cache(task): try: - response = task.result() + response = task.metrics() except Exception: response = rest_response( success=False, message=traceback.format_exc() diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 12bae7d64dfb..1aca9d986d29 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -246,7 +246,7 @@ def _poll_locked(self, timeout=None) -> None: try: # Use 1s timeout to check for subscriber closing # periodically. - fut.result(timeout=1) + fut.metrics(timeout=1) break except grpc.FutureTimeoutError: # Subscriber has closed. 
Cancel inflight request and @@ -262,8 +262,8 @@ def _poll_locked(self, timeout=None) -> None: raise if fut.done(): - self._last_batch_size = len(fut.result().pub_messages) - for msg in fut.result().pub_messages: + self._last_batch_size = len(fut.metrics().pub_messages) + for msg in fut.metrics().pub_messages: if msg.channel_type != self._channel: logger.warn(f"Ignoring message from unsubscribed channel {msg}") continue diff --git a/python/ray/tests/test_async.py b/python/ray/tests/test_async.py index ac39fe681ecd..c8f3b9c65300 100644 --- a/python/ray/tests/test_async.py +++ b/python/ray/tests/test_async.py @@ -137,7 +137,7 @@ def test_concurrent_future(ray_start_regular_shared): def cb(fut): nonlocal global_result - global_result = fut.result() + global_result = fut.metrics() fut.add_done_callback(cb) assert global_result == 1 diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index d0d1d2c307d1..5c34bcea173f 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -52,7 +52,7 @@ def test_proxy_manager_lifecycle(shutdown_only): pm.create_specific_server(client) assert pm.start_specific_server(client, JobConfig()) # Channel should be ready and corresponding to an existing server - grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5) + grpc.channel_ready_future(pm.get_channel(client)).metrics(timeout=5) proc = pm._get_server_for_client(client) assert proc.port == free_ports[0], f"Free Ports are: {free_ports}" diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 5d1bd3621f1e..ed0255365040 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -423,7 +423,7 @@ def get_trial_checkpoints_paths( # Support metrics given as paths, e.g. # "info/learner/default_policy/policy_loss". return [ - (c.value, unflattened_lookup(metric, c.result)) for c in checkpoints + (c.value, unflattened_lookup(metric, c.metrics)) for c in checkpoints ] else: raise ValueError("trial should be a string or a Trial instance.") diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 75cf4b8cb835..321f1a132d39 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -214,7 +214,7 @@ def best_checkpoints(self): return [queue_item.value for queue_item in checkpoints] def _priority(self, checkpoint): - result = flatten_dict(checkpoint.result) + result = flatten_dict(checkpoint.metrics) priority = result[self._checkpoint_score_attr] if self._checkpoint_score_desc: priority = -priority diff --git a/python/ray/tune/examples/tf_mnist_example.py b/python/ray/tune/examples/tf_mnist_example.py index e2bde13e4106..1468a18d6d90 100644 --- a/python/ray/tune/examples/tf_mnist_example.py +++ b/python/ray/tune/examples/tf_mnist_example.py @@ -109,10 +109,10 @@ def step(self): # It is important to return tf.Tensors as numpy objects. 
return { "epoch": self.iteration, - "loss": self.train_loss.result().numpy(), - "accuracy": self.train_accuracy.result().numpy() * 100, - "test_loss": self.test_loss.result().numpy(), - "mean_accuracy": self.test_accuracy.result().numpy() * 100, + "loss": self.train_loss.metrics().numpy(), + "accuracy": self.train_accuracy.metrics().numpy() * 100, + "test_loss": self.test_loss.metrics().numpy(), + "mean_accuracy": self.test_accuracy.metrics().numpy() * 100, } diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f793714262f2..466bbc606d0d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -660,7 +660,7 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): def on_restore(self): """Handles restoration completion.""" assert self.is_restoring - self.last_result = self.restoring_from.result + self.last_result = self.restoring_from.metrics self.restoring_from = None self.invalidate_json_state() diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index e83ccd82d6fa..eee6f07e98d2 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -370,7 +370,7 @@ def get_channel( # Wait for the SpecificServer to become ready. server.wait_ready() try: - grpc.channel_ready_future(server.channel).result( + grpc.channel_ready_future(server.channel).metrics( timeout=CHECK_CHANNEL_TIMEOUT_S ) return server.channel diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 1f44e0287503..bc2548db62fe 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -216,7 +216,7 @@ def _connect_channel(self, reconnecting=False) -> None: try: # Let gRPC wait for us to see if the channel becomes ready. # If it throws, we couldn't connect. - grpc.channel_ready_future(self.channel).result(timeout=timeout) + grpc.channel_ready_future(self.channel).metrics(timeout=timeout) # The HTTP2 channel is ready. Wrap the channel with the # RayletDriverStub, allowing for unary requests. self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py index 9c448474436c..469770e02435 100644 --- a/python/ray/util/dask/scheduler.py +++ b/python/ray/util/dask/scheduler.py @@ -442,7 +442,7 @@ def render_progress_bar(tracker, object_refs): from tqdm import tqdm # At this time, every task should be submitted. - total, finished = ray.get(tracker.result.remote()) + total, finished = ray.get(tracker.metrics.remote()) reported_finished_so_far = 0 pb_bar = tqdm(total=total, position=0) pb_bar.set_description("") @@ -450,7 +450,7 @@ def render_progress_bar(tracker, object_refs): ready_refs = [] while finished < total: - submitted, finished = ray.get(tracker.result.remote()) + submitted, finished = ray.get(tracker.metrics.remote()) pb_bar.update(finished - reported_finished_so_far) reported_finished_so_far = finished ready_refs, _ = ray.wait( @@ -462,7 +462,7 @@ def render_progress_bar(tracker, object_refs): time.sleep(0.1) pb_bar.close() - submitted, finished = ray.get(tracker.result.remote()) + submitted, finished = ray.get(tracker.metrics.remote()) if submitted != finished: print("Completed. 
There was state inconsistency.") from pprint import pprint diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index f97068b58377..2334d69ad7cd 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,3 +1,4 @@ +import copy import gc import heapq import logging @@ -10,6 +11,7 @@ from typing import Optional, Dict, Union, Callable, Tuple, List, Any import ray +from ray.ml import Checkpoint from ray.tune.result import NODE_IP from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI @@ -29,12 +31,21 @@ class TrackedCheckpoint: order to add metadata (e.g. the result, or the node where it has been created) and for bookkeeping purposes. + The data can be an object, a checkpoint directory, or a future to either. Because + we can't know if it's data or a directory from a future, this class expects + a ``storage_mode`` that makes the data type explicit. + + The passed metrics can be used to compare performance of different checkpoints. + The ``checkpoint_id`` is passed as an alternative to be able to order + checkpoints in time. + Args: dir_or_data: Checkpoint directory, checkpoint data, or a future to either. storage_mode: Either MEMORY or PERSISTENT. - checkpoint_id: Checkpoint number. Usually this should be monotonically + checkpoint_id: Checkpoint number. Will be used to determine checkpoint order + if metrics are not available. Usually this should be monotonically increasing for each tracked checkpoint. - result: Observed metrics for this checkpoint. This is used to determine + metrics: Observed metrics for this checkpoint. This is used to determine the value of the ``checkpoint_score_attr``. node_ip: IP of the node where the checkpoint was generated. Defaults to the current node. @@ -48,16 +59,15 @@ def __init__( dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], storage_mode: str, checkpoint_id: Optional[int] = None, - result: Optional[Dict] = None, + metrics: Optional[Dict] = None, node_ip: Optional[str] = None, ): self.dir_or_data = dir_or_data self.id = checkpoint_id self.storage_mode = storage_mode - # Todo: What to do if result is a subset of dir_or_data (dict) - self.result = result or {} - self.node_ip = node_ip or self.result.get(NODE_IP, None) + self.metrics = metrics or {} + self.node_ip = node_ip or self.metrics.get(NODE_IP, None) def commit(self, path: Optional[Path] = None) -> None: """Commit checkpoint to disk, if needed. @@ -65,7 +75,20 @@ def commit(self, path: Optional[Path] = None) -> None: Args: path: Path to commit checkpoint to. 
""" - pass + if self.storage_mode == TrackedCheckpoint.MEMORY: + # Do not persist memory checkpoints + return + + if not isinstance(self.dir_or_data, dict): + # Only persist dictionaries + return + + if not path: + # If no path is given, skip + return + + checkpoint = Checkpoint.from_dict(self.dir_or_data) + self.dir_or_data = checkpoint.to_directory(str(path)) def delete( self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None @@ -84,7 +107,7 @@ def delete( def __repr__(self): if self.storage_mode == TrackedCheckpoint.MEMORY: - return f"" + return f"" return ( f" str: sdk = sdk or get_anyscale_sdk() result = sdk.get_project(project_id) - return result.result.name + return result.metrics.name def get_cluster_name(cluster_id: str, sdk: Optional[AnyscaleSDK] = None) -> str: sdk = sdk or get_anyscale_sdk() result = sdk.get_cluster(cluster_id) - return result.result.name + return result.metrics.name diff --git a/release/ray_release/cluster_manager/full.py b/release/ray_release/cluster_manager/full.py index 32aafdf2de81..cc91380b5e15 100644 --- a/release/ray_release/cluster_manager/full.py +++ b/release/ray_release/cluster_manager/full.py @@ -37,7 +37,7 @@ def start_cluster(self, timeout: float = 600.0): idle_timeout_minutes=self.autosuspend_minutes, ) ) - self.cluster_id = result.result.id + self.cluster_id = result.metrics.id except Exception as e: raise ClusterCreationError(f"Error creating cluster: {e}") from e @@ -50,8 +50,8 @@ def start_cluster(self, timeout: float = 600.0): try: result = self.sdk.start_cluster(self.cluster_id, start_cluster_options={}) - cop_id = result.result.id - completed = result.result.completed + cop_id = result.metrics.id + completed = result.metrics.completed except Exception as e: raise ClusterStartupError( f"Error starting cluster with name " @@ -88,14 +88,14 @@ def start_cluster(self, timeout: float = 600.0): initial_retry_delay_s=2, max_retries=3, ) - completed = result.result.completed + completed = result.metrics.completed result = self.sdk.get_cluster(self.cluster_id) - if result.result.state != "Running": + if result.metrics.state != "Running": raise ClusterStartupFailed( f"Cluster did not come up - most likely the nodes are currently " f"not available. Please check the cluster startup logs: " - f"{cluster_url} (cluster state: {result.result.state})" + f"{cluster_url} (cluster state: {result.metrics.state})" ) def terminate_cluster(self, wait: bool = False): @@ -109,8 +109,8 @@ def terminate_cluster(self, wait: bool = False): return # Only do this when waiting - cop_id = result.result.id - completed = result.result.completed + cop_id = result.metrics.id + completed = result.metrics.completed while not completed: # Sleep 1 sec before next check. 
time.sleep(1) @@ -118,10 +118,10 @@ def terminate_cluster(self, wait: bool = False): cluster_operation_response = self.sdk.get_cluster_operation( cop_id, _request_timeout=30 ) - cluster_operation = cluster_operation_response.result + cluster_operation = cluster_operation_response.metrics completed = cluster_operation.completed result = self.sdk.get_cluster(self.cluster_id) - while result.result.state != "Terminated": + while result.metrics.state != "Terminated": time.sleep(1) result = self.sdk.get_cluster(self.cluster_id) diff --git a/release/ray_release/cluster_manager/minimal.py b/release/ray_release/cluster_manager/minimal.py index 37feee9d5c21..1be9c1ed28de 100644 --- a/release/ray_release/cluster_manager/minimal.py +++ b/release/ray_release/cluster_manager/minimal.py @@ -64,7 +64,7 @@ def create_cluster_env(self, _repeat: bool = True): config_json=self.cluster_env, ) ) - self.cluster_env_id = result.result.id + self.cluster_env_id = result.metrics.id except Exception as e: if _repeat: logger.warning( @@ -120,7 +120,7 @@ def build_cluster_env(self, timeout: float = 600.0): cluster_environment_id=self.cluster_env_id, config_json=config_json ) ) - build_id = result.result.id + build_id = result.metrics.id logger.info( f"Link to cluster env build: " @@ -150,7 +150,7 @@ def build_cluster_env(self, timeout: float = 600.0): next_report = next_report + REPORT_S result = self.sdk.get_build(build_id) - build = result.result + build = result.metrics if build.status == "failed": raise ClusterEnvBuildError( @@ -185,7 +185,7 @@ def fetch_build_info(self): assert self.cluster_env_build_id result = self.sdk.get_cluster_environment_build(self.cluster_env_build_id) - self.cluster_env = result.result.config_json + self.cluster_env = result.metrics.config_json def create_cluster_compute(self, _repeat: bool = True): assert self.cluster_compute_id is None @@ -236,7 +236,7 @@ def create_cluster_compute(self, _repeat: bool = True): config=self.cluster_compute, ) ) - self.cluster_compute_id = result.result.id + self.cluster_compute_id = result.metrics.id except Exception as e: if _repeat: logger.warning( diff --git a/release/ray_release/command_runner/sdk_runner.py b/release/ray_release/command_runner/sdk_runner.py index f582a601f06c..7d490400835c 100644 --- a/release/ray_release/command_runner/sdk_runner.py +++ b/release/ray_release/command_runner/sdk_runner.py @@ -94,10 +94,10 @@ def run_command( dict(session_id=self.cluster_manager.cluster_id, shell_command=full_command) ) - scd_id = result.result.id + scd_id = result.metrics.id self.last_command_scd_id = scd_id - completed = result.result.finished_at is not None + completed = result.metrics.finished_at is not None start_time = time.monotonic() timeout_at = start_time + timeout @@ -126,9 +126,9 @@ def run_command( initial_retry_delay_s=10, max_retries=3, ) - completed = result.result.finished_at + completed = result.metrics.finished_at - status_code = result.result.status_code + status_code = result.metrics.status_code time_taken = time.monotonic() - start_time if status_code != 0: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index b02184214bab..9090c6e52fdb 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -40,6 +40,9 @@ ) from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.buffers.multi_agent_replay_buffer import ( + MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer, +) from ray.rllib.utils.replay_buffers import 
MultiAgentReplayBuffer from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ( @@ -664,7 +667,7 @@ def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done): else: step_results.update(self.evaluate()) # Collect the training results from the future. - step_results.update(train_future.result()) + step_results.update(train_future.metrics()) # Sequential: train (already done above), then eval. else: step_results.update(self.evaluate()) @@ -2124,7 +2127,7 @@ def __setstate__(self, state: dict): @DeveloperAPI def _create_local_replay_buffer_if_necessary( self, config: PartialTrainerConfigDict - ) -> Optional[MultiAgentReplayBuffer]: + ) -> Optional[Union[MultiAgentReplayBuffer, Legacy_MultiAgentReplayBuffer]]: """Create a MultiAgentReplayBuffer instance if necessary. Args: @@ -2135,7 +2138,7 @@ def _create_local_replay_buffer_if_necessary( None, if local replay buffer is not needed. """ if not config.get("replay_buffer_config") or config["replay_buffer_config"].get( - "no_local_replay_buffer" or config.get("no_local_replay_buffer") + "no_local_replay_buffer" or config.get("no_local_replay_buffer"), False ): return From 886157b0ebb04a6127ce27363fa8ae5dc673ac8d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:57:52 +0100 Subject: [PATCH 40/81] Fix faulty rebase --- dashboard/modules/job/job_manager.py | 2 +- dashboard/optional_utils.py | 2 +- python/ray/_private/gcs_pubsub.py | 6 +++--- python/ray/tests/test_async.py | 2 +- python/ray/tests/test_client_proxy.py | 2 +- python/ray/tune/examples/tf_mnist_example.py | 8 ++++---- python/ray/util/client/server/proxier.py | 2 +- python/ray/util/client/worker.py | 2 +- python/ray/util/dask/scheduler.py | 6 +++--- release/ray_release/anyscale_util.py | 4 ++-- release/ray_release/cluster_manager/full.py | 20 +++++++++---------- .../ray_release/cluster_manager/minimal.py | 10 +++++----- .../ray_release/command_runner/sdk_runner.py | 8 ++++---- rllib/agents/trainer.py | 9 +++------ 14 files changed, 40 insertions(+), 43 deletions(-) diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index 758a1ca57c92..a38a59e6e220 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -260,7 +260,7 @@ async def run( # at the same time assert len(finished) == 1, "Should have only one coroutine done" [child_process_task] = finished - return_code = child_process_task.metrics() + return_code = child_process_task.result() if return_code == 0: self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED) else: diff --git a/dashboard/optional_utils.py b/dashboard/optional_utils.py index 59908d9be11e..284cb8ef9c19 100644 --- a/dashboard/optional_utils.py +++ b/dashboard/optional_utils.py @@ -208,7 +208,7 @@ async def _cache_handler(*args) -> aiohttp.web.Response: def _update_cache(task): try: - response = task.metrics() + response = task.result() except Exception: response = rest_response( success=False, message=traceback.format_exc() diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 1aca9d986d29..12bae7d64dfb 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -246,7 +246,7 @@ def _poll_locked(self, timeout=None) -> None: try: # Use 1s timeout to check for subscriber closing # periodically. - fut.metrics(timeout=1) + fut.result(timeout=1) break except grpc.FutureTimeoutError: # Subscriber has closed. 
Cancel inflight request and @@ -262,8 +262,8 @@ def _poll_locked(self, timeout=None) -> None: raise if fut.done(): - self._last_batch_size = len(fut.metrics().pub_messages) - for msg in fut.metrics().pub_messages: + self._last_batch_size = len(fut.result().pub_messages) + for msg in fut.result().pub_messages: if msg.channel_type != self._channel: logger.warn(f"Ignoring message from unsubscribed channel {msg}") continue diff --git a/python/ray/tests/test_async.py b/python/ray/tests/test_async.py index c8f3b9c65300..ac39fe681ecd 100644 --- a/python/ray/tests/test_async.py +++ b/python/ray/tests/test_async.py @@ -137,7 +137,7 @@ def test_concurrent_future(ray_start_regular_shared): def cb(fut): nonlocal global_result - global_result = fut.metrics() + global_result = fut.result() fut.add_done_callback(cb) assert global_result == 1 diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 5c34bcea173f..d0d1d2c307d1 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -52,7 +52,7 @@ def test_proxy_manager_lifecycle(shutdown_only): pm.create_specific_server(client) assert pm.start_specific_server(client, JobConfig()) # Channel should be ready and corresponding to an existing server - grpc.channel_ready_future(pm.get_channel(client)).metrics(timeout=5) + grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5) proc = pm._get_server_for_client(client) assert proc.port == free_ports[0], f"Free Ports are: {free_ports}" diff --git a/python/ray/tune/examples/tf_mnist_example.py b/python/ray/tune/examples/tf_mnist_example.py index 1468a18d6d90..e2bde13e4106 100644 --- a/python/ray/tune/examples/tf_mnist_example.py +++ b/python/ray/tune/examples/tf_mnist_example.py @@ -109,10 +109,10 @@ def step(self): # It is important to return tf.Tensors as numpy objects. return { "epoch": self.iteration, - "loss": self.train_loss.metrics().numpy(), - "accuracy": self.train_accuracy.metrics().numpy() * 100, - "test_loss": self.test_loss.metrics().numpy(), - "mean_accuracy": self.test_accuracy.metrics().numpy() * 100, + "loss": self.train_loss.result().numpy(), + "accuracy": self.train_accuracy.result().numpy() * 100, + "test_loss": self.test_loss.result().numpy(), + "mean_accuracy": self.test_accuracy.result().numpy() * 100, } diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index eee6f07e98d2..e83ccd82d6fa 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -370,7 +370,7 @@ def get_channel( # Wait for the SpecificServer to become ready. server.wait_ready() try: - grpc.channel_ready_future(server.channel).metrics( + grpc.channel_ready_future(server.channel).result( timeout=CHECK_CHANNEL_TIMEOUT_S ) return server.channel diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index bc2548db62fe..1f44e0287503 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -216,7 +216,7 @@ def _connect_channel(self, reconnecting=False) -> None: try: # Let gRPC wait for us to see if the channel becomes ready. # If it throws, we couldn't connect. - grpc.channel_ready_future(self.channel).metrics(timeout=timeout) + grpc.channel_ready_future(self.channel).result(timeout=timeout) # The HTTP2 channel is ready. Wrap the channel with the # RayletDriverStub, allowing for unary requests. 
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py index 469770e02435..9c448474436c 100644 --- a/python/ray/util/dask/scheduler.py +++ b/python/ray/util/dask/scheduler.py @@ -442,7 +442,7 @@ def render_progress_bar(tracker, object_refs): from tqdm import tqdm # At this time, every task should be submitted. - total, finished = ray.get(tracker.metrics.remote()) + total, finished = ray.get(tracker.result.remote()) reported_finished_so_far = 0 pb_bar = tqdm(total=total, position=0) pb_bar.set_description("") @@ -450,7 +450,7 @@ def render_progress_bar(tracker, object_refs): ready_refs = [] while finished < total: - submitted, finished = ray.get(tracker.metrics.remote()) + submitted, finished = ray.get(tracker.result.remote()) pb_bar.update(finished - reported_finished_so_far) reported_finished_so_far = finished ready_refs, _ = ray.wait( @@ -462,7 +462,7 @@ def render_progress_bar(tracker, object_refs): time.sleep(0.1) pb_bar.close() - submitted, finished = ray.get(tracker.metrics.remote()) + submitted, finished = ray.get(tracker.result.remote()) if submitted != finished: print("Completed. There was state inconsistency.") from pprint import pprint diff --git a/release/ray_release/anyscale_util.py b/release/ray_release/anyscale_util.py index 1055cfdbcd06..2a0ff402dade 100644 --- a/release/ray_release/anyscale_util.py +++ b/release/ray_release/anyscale_util.py @@ -41,11 +41,11 @@ def get_project_name(project_id: str, sdk: Optional[AnyscaleSDK] = None) -> str: sdk = sdk or get_anyscale_sdk() result = sdk.get_project(project_id) - return result.metrics.name + return result.result.name def get_cluster_name(cluster_id: str, sdk: Optional[AnyscaleSDK] = None) -> str: sdk = sdk or get_anyscale_sdk() result = sdk.get_cluster(cluster_id) - return result.metrics.name + return result.result.name diff --git a/release/ray_release/cluster_manager/full.py b/release/ray_release/cluster_manager/full.py index cc91380b5e15..32aafdf2de81 100644 --- a/release/ray_release/cluster_manager/full.py +++ b/release/ray_release/cluster_manager/full.py @@ -37,7 +37,7 @@ def start_cluster(self, timeout: float = 600.0): idle_timeout_minutes=self.autosuspend_minutes, ) ) - self.cluster_id = result.metrics.id + self.cluster_id = result.result.id except Exception as e: raise ClusterCreationError(f"Error creating cluster: {e}") from e @@ -50,8 +50,8 @@ def start_cluster(self, timeout: float = 600.0): try: result = self.sdk.start_cluster(self.cluster_id, start_cluster_options={}) - cop_id = result.metrics.id - completed = result.metrics.completed + cop_id = result.result.id + completed = result.result.completed except Exception as e: raise ClusterStartupError( f"Error starting cluster with name " @@ -88,14 +88,14 @@ def start_cluster(self, timeout: float = 600.0): initial_retry_delay_s=2, max_retries=3, ) - completed = result.metrics.completed + completed = result.result.completed result = self.sdk.get_cluster(self.cluster_id) - if result.metrics.state != "Running": + if result.result.state != "Running": raise ClusterStartupFailed( f"Cluster did not come up - most likely the nodes are currently " f"not available. 
Please check the cluster startup logs: " - f"{cluster_url} (cluster state: {result.metrics.state})" + f"{cluster_url} (cluster state: {result.result.state})" ) def terminate_cluster(self, wait: bool = False): @@ -109,8 +109,8 @@ def terminate_cluster(self, wait: bool = False): return # Only do this when waiting - cop_id = result.metrics.id - completed = result.metrics.completed + cop_id = result.result.id + completed = result.result.completed while not completed: # Sleep 1 sec before next check. time.sleep(1) @@ -118,10 +118,10 @@ def terminate_cluster(self, wait: bool = False): cluster_operation_response = self.sdk.get_cluster_operation( cop_id, _request_timeout=30 ) - cluster_operation = cluster_operation_response.metrics + cluster_operation = cluster_operation_response.result completed = cluster_operation.completed result = self.sdk.get_cluster(self.cluster_id) - while result.metrics.state != "Terminated": + while result.result.state != "Terminated": time.sleep(1) result = self.sdk.get_cluster(self.cluster_id) diff --git a/release/ray_release/cluster_manager/minimal.py b/release/ray_release/cluster_manager/minimal.py index 1be9c1ed28de..37feee9d5c21 100644 --- a/release/ray_release/cluster_manager/minimal.py +++ b/release/ray_release/cluster_manager/minimal.py @@ -64,7 +64,7 @@ def create_cluster_env(self, _repeat: bool = True): config_json=self.cluster_env, ) ) - self.cluster_env_id = result.metrics.id + self.cluster_env_id = result.result.id except Exception as e: if _repeat: logger.warning( @@ -120,7 +120,7 @@ def build_cluster_env(self, timeout: float = 600.0): cluster_environment_id=self.cluster_env_id, config_json=config_json ) ) - build_id = result.metrics.id + build_id = result.result.id logger.info( f"Link to cluster env build: " @@ -150,7 +150,7 @@ def build_cluster_env(self, timeout: float = 600.0): next_report = next_report + REPORT_S result = self.sdk.get_build(build_id) - build = result.metrics + build = result.result if build.status == "failed": raise ClusterEnvBuildError( @@ -185,7 +185,7 @@ def fetch_build_info(self): assert self.cluster_env_build_id result = self.sdk.get_cluster_environment_build(self.cluster_env_build_id) - self.cluster_env = result.metrics.config_json + self.cluster_env = result.result.config_json def create_cluster_compute(self, _repeat: bool = True): assert self.cluster_compute_id is None @@ -236,7 +236,7 @@ def create_cluster_compute(self, _repeat: bool = True): config=self.cluster_compute, ) ) - self.cluster_compute_id = result.metrics.id + self.cluster_compute_id = result.result.id except Exception as e: if _repeat: logger.warning( diff --git a/release/ray_release/command_runner/sdk_runner.py b/release/ray_release/command_runner/sdk_runner.py index 7d490400835c..f582a601f06c 100644 --- a/release/ray_release/command_runner/sdk_runner.py +++ b/release/ray_release/command_runner/sdk_runner.py @@ -94,10 +94,10 @@ def run_command( dict(session_id=self.cluster_manager.cluster_id, shell_command=full_command) ) - scd_id = result.metrics.id + scd_id = result.result.id self.last_command_scd_id = scd_id - completed = result.metrics.finished_at is not None + completed = result.result.finished_at is not None start_time = time.monotonic() timeout_at = start_time + timeout @@ -126,9 +126,9 @@ def run_command( initial_retry_delay_s=10, max_retries=3, ) - completed = result.metrics.finished_at + completed = result.result.finished_at - status_code = result.metrics.status_code + status_code = result.result.status_code time_taken = time.monotonic() - start_time if 
status_code != 0: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 9090c6e52fdb..b02184214bab 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -40,9 +40,6 @@ ) from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.execution.buffers.multi_agent_replay_buffer import ( - MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer, -) from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ( @@ -667,7 +664,7 @@ def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done): else: step_results.update(self.evaluate()) # Collect the training results from the future. - step_results.update(train_future.metrics()) + step_results.update(train_future.result()) # Sequential: train (already done above), then eval. else: step_results.update(self.evaluate()) @@ -2127,7 +2124,7 @@ def __setstate__(self, state: dict): @DeveloperAPI def _create_local_replay_buffer_if_necessary( self, config: PartialTrainerConfigDict - ) -> Optional[Union[MultiAgentReplayBuffer, Legacy_MultiAgentReplayBuffer]]: + ) -> Optional[MultiAgentReplayBuffer]: """Create a MultiAgentReplayBuffer instance if necessary. Args: @@ -2138,7 +2135,7 @@ def _create_local_replay_buffer_if_necessary( None, if local replay buffer is not needed. """ if not config.get("replay_buffer_config") or config["replay_buffer_config"].get( - "no_local_replay_buffer" or config.get("no_local_replay_buffer"), False + "no_local_replay_buffer" or config.get("no_local_replay_buffer") ): return From 8a99f69f592d0a03dd1702c1ece3e070a3b8485e Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:43:37 +0100 Subject: [PATCH 41/81] [tune/train] Consolidate checkpoint manager 2: Ray Train --- python/ray/train/checkpoint.py | 311 ++++++++++++++------------------- python/ray/train/trainer.py | 11 +- 2 files changed, 136 insertions(+), 186 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index dd03ed3197eb..ffbddb5bd4a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -1,27 +1,28 @@ -import heapq import logging -import numbers -import os -from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import TIMESTAMP, TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID +from ray.train.constants import ( + TIMESTAMP, + TRAIN_CHECKPOINT_SUBDIR, + TUNE_CHECKPOINT_FILE_NAME, + TUNE_CHECKPOINT_ID, + TUNE_INSTALLED, +) from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.util import PublicAPI -from ray.util.ml_utils.util import is_nan +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointManager as CommonCheckpointManager, + TrackedCheckpoint, + CheckpointStrategy, +) if TUNE_INSTALLED: from ray import tune else: tune = None -MAX = "max" -MIN = "min" - logger = logging.getLogger(__name__) @@ -34,64 +35,58 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: return cloudpickle.load(f) -@PublicAPI(stability="beta") -@dataclass -class CheckpointStrategy: - """Configurable parameters for defining the Train checkpointing strategy. - - Default behavior is to persist all checkpoints to disk. 
If - ``num_to_keep`` is set, the default retention policy is to keep the - checkpoints with maximum timestamp, i.e. the most recent checkpoints. - - Args: - num_to_keep (Optional[int]): The number of checkpoints to keep - on disk for this run. If a checkpoint is persisted to disk after - there are already this many checkpoints, then an existing - checkpoint will be deleted. If this is ``None`` then checkpoints - will not be deleted. If this is ``0`` then no checkpoints will be - persisted to disk. - checkpoint_score_attribute (str): The attribute that will be used to - score checkpoints to determine which checkpoints should be kept - on disk when there are greater than ``num_to_keep`` checkpoints. - This attribute must be a key from the checkpoint - dictionary which has a numerical value. - checkpoint_score_order (str). Either "max" or "min". - If "max", then checkpoints with highest values of - ``checkpoint_score_attribute`` will be kept. - If "min", then checkpoints with lowest values of - ``checkpoint_score_attribute`` will be kept. +class _NotYetPersistedCheckpoint(TrackedCheckpoint): + """Tracked checkpoint that is not yet persisted to disk. + + This checkpoint class supports lazy writing. The checkpoint manager will + only call ``commit()`` if the checkpoint should be kept on disk. This class + will only then write checkpoint data to disk. """ - num_to_keep: Optional[int] = None - checkpoint_score_attribute: str = TIMESTAMP - checkpoint_score_order: str = MAX + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def __post_init__(self): - if self.num_to_keep is not None and self.num_to_keep < 0: - raise ValueError( - f"Received invalidate num_to_keep: " - f"{self.num_to_keep}. " - f"Must be None or non-negative integer." - ) - if self.checkpoint_score_order not in (MAX, MIN): - raise ValueError( - f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' - ) + self._data_to_commit = self.dir_or_data + self.dir_or_data = None + @property + def committed(self) -> bool: + return not self._data_to_commit -class PersistedCheckpoint: - def __init__(self, path, priority): - self.path = path - self.priority = priority + def commit(self, path: Optional[Path] = None): + if self.committed: + return - def __lt__(self, other): - return self.priority < other.priority + assert path - def __repr__(self): - return f"PersistedCheckpoint({repr(self.path)})" + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(self._data_to_commit, f) + logger.debug(f"Checkpoint successfully written to: {path}") + self.dir_or_data = path + self._data_to_commit = None -class CheckpointManager: + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): + if not self.committed: + return + return super().delete(delete_fn=delete_fn) + + @classmethod + def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): + new_checkpoint = cls( + dir_or_data=checkpoint.dir_or_data, + storage_mode=TrackedCheckpoint.PERSISTENT, + checkpoint_id=checkpoint.id, + result=checkpoint.result, + node_ip=checkpoint.node_ip, + ) + return new_checkpoint + + +class CheckpointManager(CommonCheckpointManager): """Manages checkpoint processing, writing, and loading. @@ -119,55 +114,16 @@ class CheckpointManager: checkpoint may not be saved to disk. 
""" - def on_init(self, **kwargs): - """Checkpoint code executed during BackendExecutor init.""" - self.latest_checkpoint = None - - # Incremental unique checkpoint ID of this run. - self._latest_checkpoint_id = 0 - - # Used for keeping top K checkpoints. - self._top_persisted_checkpoints = [] - - # Best checkpoint altogether. - # Used for exposing best_checkpoint_path. - self._best_persisted_checkpoint = None - - def on_start_training( - self, - checkpoint_strategy: Optional[CheckpointStrategy], - run_dir: Path, - latest_checkpoint_id: Optional[int] = None, - ): - """Checkpoint code executed during BackendExecutor start_training.""" - # Restart checkpointing. - self._latest_checkpoint_id = latest_checkpoint_id if latest_checkpoint_id else 0 - self._checkpoint_strategy = ( - CheckpointStrategy() if checkpoint_strategy is None else checkpoint_strategy - ) + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir - def _process_checkpoint( - self, - checkpoint_results: List[TrainingResult], - decode_checkpoint_fn: Callable, - ) -> None: - """Perform all processing for a checkpoint.""" - - # Get checkpoint from first worker. - checkpoint = checkpoint_results[0].data + super().__init__(checkpoint_strategy=checkpoint_strategy) - # Decode checkpoint. - checkpoint = decode_checkpoint_fn(checkpoint) + self._validate_checkpoint_strategy() - # Store checkpoint in memory. - self.latest_checkpoint = checkpoint - - # Write checkpoint to disk. - self.write_checkpoint(checkpoint) - - # Increment checkpoint id. - self._latest_checkpoint_id += 1 + def _validate_checkpoint_strategy(self): + if self._checkpoint_strategy.checkpoint_score_attribute is None: + self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP def _load_checkpoint( self, checkpoint_to_load: Optional[Union[Dict, str, Path]] @@ -181,94 +137,77 @@ def _load_checkpoint( # Load checkpoint from path. return load_checkpoint_from_path(checkpoint_to_load) - def write_checkpoint(self, checkpoint: Dict): - """Writes checkpoint to disk.""" - num_to_keep = self._checkpoint_strategy.num_to_keep + def _process_checkpoint( + self, + checkpoint_results: List[TrainingResult], + decode_checkpoint_fn: Callable, + ) -> None: + """Ray Train entrypoint. Perform all processing for a checkpoint.""" + # Get checkpoint from first worker. + checkpoint_data = checkpoint_results[0].data - if num_to_keep == 0: - # Checkpoints should not be persisted to disk. - return + # Decode checkpoint. + checkpoint_data = decode_checkpoint_fn(checkpoint_data) - checkpoint_score_attribute = ( - self._checkpoint_strategy.checkpoint_score_attribute - ) - checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_attribute not in checkpoint: + score_attr = self._checkpoint_strategy.checkpoint_score_attribute + if ( + self._checkpoint_strategy.num_to_keep != 0 + and score_attr not in checkpoint_data + ): raise ValueError( f"Unable to persist checkpoint for " f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute}. " + f"{score_attr}. " f"Include this attribute in the call to " f"train.save_checkpoint." ) - checkpoint_score = checkpoint[checkpoint_score_attribute] - if not isinstance(checkpoint_score, numbers.Number): - raise ValueError( - f"Unable to persist checkpoint for " - f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute} with value " - f"{checkpoint_score}. " - f"This attribute must be numerical." 
+ tracked_checkpoint = TrackedCheckpoint( + dir_or_data=checkpoint_data, + checkpoint_id=self._latest_checkpoint_id, + storage_mode=TrackedCheckpoint.MEMORY, + result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + ) + self.register_checkpoint(checkpoint=tracked_checkpoint) + + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + # Always update the latest memory checkpoint + self._replace_latest_memory_checkpoint(checkpoint) + + # Only process further if we consider keeping this checkpoint on disk + if self._checkpoint_strategy.num_to_keep != 0: + not_yet_persisted_checkpoint = ( + _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) ) + self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - def priority(checkpoint_score_order, checkpoint_score): - # Treat NaN as worst - # The tuple structure is (not is_nan(), metric), which makes - # the nan values to be always considered as the worst - # metrics by the heap - if checkpoint_score_order != MAX: - checkpoint_score = -checkpoint_score - return (not is_nan(checkpoint_score), checkpoint_score) + self._latest_checkpoint_id += 1 + + def _get_next_checkpoint_path(self) -> Optional[Path]: + """Path to the next checkpoint to persist.""" + checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + return self.latest_checkpoint_dir.joinpath(checkpoint_file) - checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) + def on_start_training( + self, + checkpoint_strategy: Optional[CheckpointStrategy], + run_dir: str, + latest_checkpoint_id: Optional[int] = 0, + ): + checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + self._checkpoint_strategy = checkpoint_strategy - persisted_checkpoint = PersistedCheckpoint( - self.next_checkpoint_path, checkpoint_priority - ) + self._validate_checkpoint_strategy() - def write_to_disk(path: Path): - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(checkpoint, f) - logger.debug(f"Checkpoint successfully written to: " f"{path}") - - def remove_from_disk(path: Path): - os.remove(path) - - if num_to_keep is None: - # Keep all checkpoints. - write_to_disk(self.next_checkpoint_path) - elif len(self._top_persisted_checkpoints) < num_to_keep: - # Keep first num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - heapq.heappush(self._top_persisted_checkpoints, persisted_checkpoint) - elif ( - persisted_checkpoint.priority > self._top_persisted_checkpoints[0].priority - ): - # Keep top num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - worst_checkpoint = heapq.heappushpop( - self._top_persisted_checkpoints, persisted_checkpoint - ) - worst_checkpoint_path = worst_checkpoint.path - remove_from_disk(worst_checkpoint_path) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint_path}.") - else: - # If the latest checkpoint has the same or lower priority, skip it. - logger.debug( - f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." - ) + self.run_dir = run_dir + self._latest_checkpoint_id = latest_checkpoint_id or 0 - # Update single best checkpoint. - if ( - self._best_persisted_checkpoint is None - or persisted_checkpoint.priority > self._best_persisted_checkpoint.priority - ): - # If the latest checkpoint has the same or lower priority, skip it. 
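To make the removed priority helper above easier to follow in isolation, here is a self-contained sketch of the same ordering rule; it substitutes math.isnan for the is_nan utility and plain floats for checkpoint objects.

import heapq
import math

def priority(order: str, score: float):
    # "min" order is handled by negating the score; a NaN score always sorts
    # first (i.e. as the worst entry) because False < True in tuple comparison.
    if order != "max":
        score = -score
    return (not math.isnan(score), score)

heap = [priority("max", s) for s in (0.5, float("nan"), 0.9)]
heapq.heapify(heap)
assert heap[0][0] is False  # the NaN entry sits at the top, so it is evicted first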
- self._best_persisted_checkpoint = persisted_checkpoint + # Train-specific attributes + @property + def latest_checkpoint(self): + if not self._latest_memory_checkpoint: + return None + return self._latest_memory_checkpoint.dir_or_data @property def latest_checkpoint_dir(self) -> Optional[Path]: @@ -294,7 +233,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return self._best_persisted_checkpoint.path + return Path(self._best_persisted_checkpoint.dir_or_data) else: return None @@ -328,16 +267,22 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) + def _decide_what_to_do_with_checkpoint( + self, checkpoint: _NotYetPersistedCheckpoint + ): + assert isinstance(checkpoint, _NotYetPersistedCheckpoint) + assert not checkpoint.committed + + self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) # Use a standard file name so that we know which file to load # the checkpoint from. file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - with file_path.open("wb") as f: - cloudpickle.dump(checkpoint, f) + checkpoint.commit(file_path) + + return super()._decide_what_to_do_with_checkpoint(checkpoint) def construct_checkpoint_file_name(checkpoint_id: int) -> str: diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 798dd99454a0..efb39e9c3535 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,11 +224,16 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) + # Todo (krfricke): Initialize checkpoint manager here with final values + # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager() + self.checkpoint_manager = TuneCheckpointManager( + checkpoint_strategy=None, run_dir=None + ) else: - self.checkpoint_manager = CheckpointManager() - self.checkpoint_manager.on_init() + self.checkpoint_manager = CheckpointManager( + checkpoint_strategy=None, run_dir=None + ) def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From 7e48c0b3b70fcfddda39a4c134edf9e2114db22d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:40:35 +0100 Subject: [PATCH 42/81] Adjust to changes in base PR --- python/ray/train/checkpoint.py | 81 +++----------------------- python/ray/train/tests/test_trainer.py | 1 + 2 files changed, 8 insertions(+), 74 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index ffbddb5bd4a8..1bc2349ff5b7 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional, Dict, Union, Callable -from ray import cloudpickle +from ray.ml import Checkpoint from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, @@ -31,59 +31,8 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: checkpoint_path = Path(checkpoint_to_load).expanduser() if not checkpoint_path.exists(): raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") - with 
checkpoint_path.open("rb") as f: - return cloudpickle.load(f) - - -class _NotYetPersistedCheckpoint(TrackedCheckpoint): - """Tracked checkpoint that is not yet persisted to disk. - - This checkpoint class supports lazy writing. The checkpoint manager will - only call ``commit()`` if the checkpoint should be kept on disk. This class - will only then write checkpoint data to disk. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._data_to_commit = self.dir_or_data - self.dir_or_data = None - - @property - def committed(self) -> bool: - return not self._data_to_commit - - def commit(self, path: Optional[Path] = None): - if self.committed: - return - - assert path - - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(self._data_to_commit, f) - logger.debug(f"Checkpoint successfully written to: {path}") - - self.dir_or_data = path - self._data_to_commit = None - - def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): - if not self.committed: - return - return super().delete(delete_fn=delete_fn) - - @classmethod - def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): - new_checkpoint = cls( - dir_or_data=checkpoint.dir_or_data, - storage_mode=TrackedCheckpoint.PERSISTENT, - checkpoint_id=checkpoint.id, - result=checkpoint.result, - node_ip=checkpoint.node_ip, - ) - return new_checkpoint + checkpoint = Checkpoint.from_directory(str(checkpoint_path)) + return checkpoint.to_dict() class CheckpointManager(CommonCheckpointManager): @@ -114,6 +63,8 @@ class CheckpointManager(CommonCheckpointManager): checkpoint may not be saved to disk. """ + _persist_memory_checkpoints = True + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir @@ -166,23 +117,10 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=TrackedCheckpoint.MEMORY, - result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) - def register_checkpoint(self, checkpoint: TrackedCheckpoint): - # Always update the latest memory checkpoint - self._replace_latest_memory_checkpoint(checkpoint) - - # Only process further if we consider keeping this checkpoint on disk - if self._checkpoint_strategy.num_to_keep != 0: - not_yet_persisted_checkpoint = ( - _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) - ) - self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - - self._latest_checkpoint_id += 1 - def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) @@ -267,12 +205,7 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint( - self, checkpoint: _NotYetPersistedCheckpoint - ): - assert isinstance(checkpoint, _NotYetPersistedCheckpoint) - assert not checkpoint.committed - + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. 
with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index dc72e21a939f..cc017de709b9 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -15,6 +15,7 @@ from ray.train.constants import TRAIN_ENABLE_WORKER_SPREAD_ENV from ray.train.torch import TorchConfig from ray.train.tensorflow import TensorflowConfig + from ray.train.horovod import HorovodConfig from ray.train.callbacks.callback import TrainingCallback from ray.train.worker_group import WorkerGroup From 51fa58425e08f74ca626956179a9c42a69145dc2 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:59:48 +0100 Subject: [PATCH 43/81] Fix faulty rebase --- python/ray/train/trainer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index efb39e9c3535..798dd99454a0 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,16 +224,11 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) - # Todo (krfricke): Initialize checkpoint manager here with final values - # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager( - checkpoint_strategy=None, run_dir=None - ) + self.checkpoint_manager = TuneCheckpointManager() else: - self.checkpoint_manager = CheckpointManager( - checkpoint_strategy=None, run_dir=None - ) + self.checkpoint_manager = CheckpointManager() + self.checkpoint_manager.on_init() def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From 731aef3562da5774f98d9ea6cd2052e5a1720ac6 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:01:09 +0100 Subject: [PATCH 44/81] Fix faulty rebase --- dashboard/modules/job/job_manager.py | 2 +- dashboard/optional_utils.py | 2 +- python/ray/_private/gcs_pubsub.py | 6 +++--- python/ray/tests/test_async.py | 2 +- python/ray/tests/test_client_proxy.py | 2 +- python/ray/util/client/server/proxier.py | 2 +- python/ray/util/client/worker.py | 2 +- python/ray/util/dask/scheduler.py | 6 +++--- release/ray_release/anyscale_util.py | 4 ++-- release/ray_release/cluster_manager/full.py | 20 +++++++++---------- .../ray_release/cluster_manager/minimal.py | 10 +++++----- .../ray_release/command_runner/sdk_runner.py | 8 ++++---- rllib/agents/trainer.py | 9 +++------ 13 files changed, 36 insertions(+), 39 deletions(-) diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index 758a1ca57c92..a38a59e6e220 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -260,7 +260,7 @@ async def run( # at the same time assert len(finished) == 1, "Should have only one coroutine done" [child_process_task] = finished - return_code = child_process_task.metrics() + return_code = child_process_task.result() if return_code == 0: self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED) else: diff --git a/dashboard/optional_utils.py b/dashboard/optional_utils.py index 59908d9be11e..284cb8ef9c19 100644 --- a/dashboard/optional_utils.py +++ b/dashboard/optional_utils.py @@ -208,7 +208,7 @@ async def _cache_handler(*args) -> aiohttp.web.Response: def _update_cache(task): try: - response = task.metrics() + response = task.result() except Exception: response = rest_response( success=False, message=traceback.format_exc() 
diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 1aca9d986d29..12bae7d64dfb 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -246,7 +246,7 @@ def _poll_locked(self, timeout=None) -> None: try: # Use 1s timeout to check for subscriber closing # periodically. - fut.metrics(timeout=1) + fut.result(timeout=1) break except grpc.FutureTimeoutError: # Subscriber has closed. Cancel inflight request and @@ -262,8 +262,8 @@ def _poll_locked(self, timeout=None) -> None: raise if fut.done(): - self._last_batch_size = len(fut.metrics().pub_messages) - for msg in fut.metrics().pub_messages: + self._last_batch_size = len(fut.result().pub_messages) + for msg in fut.result().pub_messages: if msg.channel_type != self._channel: logger.warn(f"Ignoring message from unsubscribed channel {msg}") continue diff --git a/python/ray/tests/test_async.py b/python/ray/tests/test_async.py index c8f3b9c65300..ac39fe681ecd 100644 --- a/python/ray/tests/test_async.py +++ b/python/ray/tests/test_async.py @@ -137,7 +137,7 @@ def test_concurrent_future(ray_start_regular_shared): def cb(fut): nonlocal global_result - global_result = fut.metrics() + global_result = fut.result() fut.add_done_callback(cb) assert global_result == 1 diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 5c34bcea173f..d0d1d2c307d1 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -52,7 +52,7 @@ def test_proxy_manager_lifecycle(shutdown_only): pm.create_specific_server(client) assert pm.start_specific_server(client, JobConfig()) # Channel should be ready and corresponding to an existing server - grpc.channel_ready_future(pm.get_channel(client)).metrics(timeout=5) + grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5) proc = pm._get_server_for_client(client) assert proc.port == free_ports[0], f"Free Ports are: {free_ports}" diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index eee6f07e98d2..e83ccd82d6fa 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -370,7 +370,7 @@ def get_channel( # Wait for the SpecificServer to become ready. server.wait_ready() try: - grpc.channel_ready_future(server.channel).metrics( + grpc.channel_ready_future(server.channel).result( timeout=CHECK_CHANNEL_TIMEOUT_S ) return server.channel diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index bc2548db62fe..1f44e0287503 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -216,7 +216,7 @@ def _connect_channel(self, reconnecting=False) -> None: try: # Let gRPC wait for us to see if the channel becomes ready. # If it throws, we couldn't connect. - grpc.channel_ready_future(self.channel).metrics(timeout=timeout) + grpc.channel_ready_future(self.channel).result(timeout=timeout) # The HTTP2 channel is ready. Wrap the channel with the # RayletDriverStub, allowing for unary requests. self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py index 469770e02435..9c448474436c 100644 --- a/python/ray/util/dask/scheduler.py +++ b/python/ray/util/dask/scheduler.py @@ -442,7 +442,7 @@ def render_progress_bar(tracker, object_refs): from tqdm import tqdm # At this time, every task should be submitted. 
- total, finished = ray.get(tracker.metrics.remote()) + total, finished = ray.get(tracker.result.remote()) reported_finished_so_far = 0 pb_bar = tqdm(total=total, position=0) pb_bar.set_description("") @@ -450,7 +450,7 @@ def render_progress_bar(tracker, object_refs): ready_refs = [] while finished < total: - submitted, finished = ray.get(tracker.metrics.remote()) + submitted, finished = ray.get(tracker.result.remote()) pb_bar.update(finished - reported_finished_so_far) reported_finished_so_far = finished ready_refs, _ = ray.wait( @@ -462,7 +462,7 @@ def render_progress_bar(tracker, object_refs): time.sleep(0.1) pb_bar.close() - submitted, finished = ray.get(tracker.metrics.remote()) + submitted, finished = ray.get(tracker.result.remote()) if submitted != finished: print("Completed. There was state inconsistency.") from pprint import pprint diff --git a/release/ray_release/anyscale_util.py b/release/ray_release/anyscale_util.py index 1055cfdbcd06..2a0ff402dade 100644 --- a/release/ray_release/anyscale_util.py +++ b/release/ray_release/anyscale_util.py @@ -41,11 +41,11 @@ def get_project_name(project_id: str, sdk: Optional[AnyscaleSDK] = None) -> str: sdk = sdk or get_anyscale_sdk() result = sdk.get_project(project_id) - return result.metrics.name + return result.result.name def get_cluster_name(cluster_id: str, sdk: Optional[AnyscaleSDK] = None) -> str: sdk = sdk or get_anyscale_sdk() result = sdk.get_cluster(cluster_id) - return result.metrics.name + return result.result.name diff --git a/release/ray_release/cluster_manager/full.py b/release/ray_release/cluster_manager/full.py index cc91380b5e15..32aafdf2de81 100644 --- a/release/ray_release/cluster_manager/full.py +++ b/release/ray_release/cluster_manager/full.py @@ -37,7 +37,7 @@ def start_cluster(self, timeout: float = 600.0): idle_timeout_minutes=self.autosuspend_minutes, ) ) - self.cluster_id = result.metrics.id + self.cluster_id = result.result.id except Exception as e: raise ClusterCreationError(f"Error creating cluster: {e}") from e @@ -50,8 +50,8 @@ def start_cluster(self, timeout: float = 600.0): try: result = self.sdk.start_cluster(self.cluster_id, start_cluster_options={}) - cop_id = result.metrics.id - completed = result.metrics.completed + cop_id = result.result.id + completed = result.result.completed except Exception as e: raise ClusterStartupError( f"Error starting cluster with name " @@ -88,14 +88,14 @@ def start_cluster(self, timeout: float = 600.0): initial_retry_delay_s=2, max_retries=3, ) - completed = result.metrics.completed + completed = result.result.completed result = self.sdk.get_cluster(self.cluster_id) - if result.metrics.state != "Running": + if result.result.state != "Running": raise ClusterStartupFailed( f"Cluster did not come up - most likely the nodes are currently " f"not available. Please check the cluster startup logs: " - f"{cluster_url} (cluster state: {result.metrics.state})" + f"{cluster_url} (cluster state: {result.result.state})" ) def terminate_cluster(self, wait: bool = False): @@ -109,8 +109,8 @@ def terminate_cluster(self, wait: bool = False): return # Only do this when waiting - cop_id = result.metrics.id - completed = result.metrics.completed + cop_id = result.result.id + completed = result.result.completed while not completed: # Sleep 1 sec before next check. 
time.sleep(1) @@ -118,10 +118,10 @@ def terminate_cluster(self, wait: bool = False): cluster_operation_response = self.sdk.get_cluster_operation( cop_id, _request_timeout=30 ) - cluster_operation = cluster_operation_response.metrics + cluster_operation = cluster_operation_response.result completed = cluster_operation.completed result = self.sdk.get_cluster(self.cluster_id) - while result.metrics.state != "Terminated": + while result.result.state != "Terminated": time.sleep(1) result = self.sdk.get_cluster(self.cluster_id) diff --git a/release/ray_release/cluster_manager/minimal.py b/release/ray_release/cluster_manager/minimal.py index 1be9c1ed28de..37feee9d5c21 100644 --- a/release/ray_release/cluster_manager/minimal.py +++ b/release/ray_release/cluster_manager/minimal.py @@ -64,7 +64,7 @@ def create_cluster_env(self, _repeat: bool = True): config_json=self.cluster_env, ) ) - self.cluster_env_id = result.metrics.id + self.cluster_env_id = result.result.id except Exception as e: if _repeat: logger.warning( @@ -120,7 +120,7 @@ def build_cluster_env(self, timeout: float = 600.0): cluster_environment_id=self.cluster_env_id, config_json=config_json ) ) - build_id = result.metrics.id + build_id = result.result.id logger.info( f"Link to cluster env build: " @@ -150,7 +150,7 @@ def build_cluster_env(self, timeout: float = 600.0): next_report = next_report + REPORT_S result = self.sdk.get_build(build_id) - build = result.metrics + build = result.result if build.status == "failed": raise ClusterEnvBuildError( @@ -185,7 +185,7 @@ def fetch_build_info(self): assert self.cluster_env_build_id result = self.sdk.get_cluster_environment_build(self.cluster_env_build_id) - self.cluster_env = result.metrics.config_json + self.cluster_env = result.result.config_json def create_cluster_compute(self, _repeat: bool = True): assert self.cluster_compute_id is None @@ -236,7 +236,7 @@ def create_cluster_compute(self, _repeat: bool = True): config=self.cluster_compute, ) ) - self.cluster_compute_id = result.metrics.id + self.cluster_compute_id = result.result.id except Exception as e: if _repeat: logger.warning( diff --git a/release/ray_release/command_runner/sdk_runner.py b/release/ray_release/command_runner/sdk_runner.py index 7d490400835c..f582a601f06c 100644 --- a/release/ray_release/command_runner/sdk_runner.py +++ b/release/ray_release/command_runner/sdk_runner.py @@ -94,10 +94,10 @@ def run_command( dict(session_id=self.cluster_manager.cluster_id, shell_command=full_command) ) - scd_id = result.metrics.id + scd_id = result.result.id self.last_command_scd_id = scd_id - completed = result.metrics.finished_at is not None + completed = result.result.finished_at is not None start_time = time.monotonic() timeout_at = start_time + timeout @@ -126,9 +126,9 @@ def run_command( initial_retry_delay_s=10, max_retries=3, ) - completed = result.metrics.finished_at + completed = result.result.finished_at - status_code = result.metrics.status_code + status_code = result.result.status_code time_taken = time.monotonic() - start_time if status_code != 0: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 9090c6e52fdb..b02184214bab 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -40,9 +40,6 @@ ) from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.execution.buffers.multi_agent_replay_buffer import ( - MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer, -) from ray.rllib.utils.replay_buffers import 
MultiAgentReplayBuffer from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ( @@ -667,7 +664,7 @@ def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done): else: step_results.update(self.evaluate()) # Collect the training results from the future. - step_results.update(train_future.metrics()) + step_results.update(train_future.result()) # Sequential: train (already done above), then eval. else: step_results.update(self.evaluate()) @@ -2127,7 +2124,7 @@ def __setstate__(self, state: dict): @DeveloperAPI def _create_local_replay_buffer_if_necessary( self, config: PartialTrainerConfigDict - ) -> Optional[Union[MultiAgentReplayBuffer, Legacy_MultiAgentReplayBuffer]]: + ) -> Optional[MultiAgentReplayBuffer]: """Create a MultiAgentReplayBuffer instance if necessary. Args: @@ -2138,7 +2135,7 @@ def _create_local_replay_buffer_if_necessary( None, if local replay buffer is not needed. """ if not config.get("replay_buffer_config") or config["replay_buffer_config"].get( - "no_local_replay_buffer" or config.get("no_local_replay_buffer"), False + "no_local_replay_buffer" or config.get("no_local_replay_buffer") ): return From c76e0b17ed80b0bc5578b7a95c5fe1c0569b939e Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:03:36 +0100 Subject: [PATCH 45/81] Rename results - metrics --- python/ray/tune/ray_trial_executor.py | 4 +-- .../ray/tune/tests/test_checkpoint_manager.py | 30 +++++++++---------- python/ray/tune/tests/test_trial_runner_2.py | 2 +- .../tune/tests/test_trial_runner_callbacks.py | 2 +- python/ray/tune/tests/test_trial_scheduler.py | 4 +-- .../tune/tests/test_trial_scheduler_pbt.py | 2 +- .../test_trial_scheduler_resource_changing.py | 2 +- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index cdf24a982c8f..e215d8233a7c 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -693,13 +693,13 @@ def save( if storage == TrackedCheckpoint.MEMORY: value = trial.runner.save_to_object.remote() checkpoint = TrackedCheckpoint( - dir_or_data=value, storage_mode=storage, result=result + dir_or_data=value, storage_mode=storage, metrics=result ) trial.on_checkpoint(checkpoint) else: value = trial.runner.save.remote() checkpoint = TrackedCheckpoint( - dir_or_data=value, storage_mode=storage, result=result + dir_or_data=value, storage_mode=storage, metrics=result ) trial.saving_to = checkpoint self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index cb7d414a4056..14679640aa27 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -28,13 +28,13 @@ def testNewestCheckpoint(self): memory_checkpoint = TrackedCheckpoint( dir_or_data={0}, storage_mode=TrackedCheckpoint.MEMORY, - result=self.mock_result(0, 0), + metrics=self.mock_result(0, 0), ) checkpoint_manager.on_checkpoint(memory_checkpoint) persistent_checkpoint = TrackedCheckpoint( dir_or_data={1}, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(1, 1), + metrics=self.mock_result(1, 1), ) checkpoint_manager.on_checkpoint(persistent_checkpoint) self.assertEqual( @@ -52,7 +52,7 @@ def testOnCheckpointOrdered(self): TrackedCheckpoint( dir_or_data={i}, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(i, i), + 
metrics=self.mock_result(i, i), ) for i in range(3) ] @@ -84,7 +84,7 @@ def testOnCheckpointUnordered(self): TrackedCheckpoint( dir_or_data={i}, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(i, i), + metrics=self.mock_result(i, i), ) for i in range(3, -1, -1) ] @@ -118,7 +118,7 @@ def testBestCheckpoints(self): TrackedCheckpoint( dir_or_data=i, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(i, i), + metrics=self.mock_result(i, i), ) for i in range(8) ] @@ -143,14 +143,14 @@ def testBestCheckpointsWithNan(self): TrackedCheckpoint( dir_or_data=None, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(float("nan"), i), + metrics=self.mock_result(float("nan"), i), ) for i in range(2) ] + [ TrackedCheckpoint( dir_or_data=3, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(0, 3), + metrics=self.mock_result(0, 3), ) ] @@ -175,7 +175,7 @@ def testBestCheckpointsOnlyNan(self): TrackedCheckpoint( dir_or_data=i, storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(float("nan"), i), + metrics=self.mock_result(float("nan"), i), ) for i in range(4) ] @@ -199,7 +199,7 @@ def testOnCheckpointUnavailableAttribute(self): no_attr_checkpoint = TrackedCheckpoint( dir_or_data=0, storage_mode=TrackedCheckpoint.PERSISTENT, - result={}, + metrics={}, ) with patch.object(logger, "error") as log_error_mock: @@ -215,12 +215,12 @@ def testOnMemoryCheckpoint(self): TrackedCheckpoint( dir_or_data=0, storage_mode=TrackedCheckpoint.MEMORY, - result=self.mock_result(0, 0), + metrics=self.mock_result(0, 0), ), TrackedCheckpoint( dir_or_data=0, storage_mode=TrackedCheckpoint.MEMORY, - result=self.mock_result(0, 0), + metrics=self.mock_result(0, 0), ), ] checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) @@ -249,22 +249,22 @@ def testSameCheckpoint(self): TrackedCheckpoint( dir_or_data=tmpfiles[0], storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(5, 5), + metrics=self.mock_result(5, 5), ), TrackedCheckpoint( dir_or_data=tmpfiles[1], storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(10, 10), + metrics=self.mock_result(10, 10), ), TrackedCheckpoint( dir_or_data=tmpfiles[2], storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(0, 0), + metrics=self.mock_result(0, 0), ), TrackedCheckpoint( dir_or_data=tmpfiles[1], storage_mode=TrackedCheckpoint.PERSISTENT, - result=self.mock_result(20, 20), + metrics=self.mock_result(20, 20), ), ] for checkpoint in checkpoints: diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 4d69c94e46d0..d0a5dc084ad5 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -336,7 +336,7 @@ def write_checkpoint(trial: Trial, index: int): tune_cp = TrackedCheckpoint( dir_or_data=checkpoint_dir, storage_mode=TrackedCheckpoint.PERSISTENT, - result=result, + metrics=result, ) trial.saving_to = tune_cp diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 7f9d4b97383f..4ae7a7da2a3c 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -153,7 +153,7 @@ def testCallbackSteps(self): cp = TrackedCheckpoint( dir_or_data=ray.put(1), storage_mode=TrackedCheckpoint.PERSISTENT, - result={TRAINING_ITERATION: 0}, + metrics={TRAINING_ITERATION: 0}, ) trials[0].saving_to = cp diff --git 
a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 98e1c93d519c..0af6fa1f084f 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -253,7 +253,7 @@ def save(self, trial, type=TrackedCheckpoint.PERSISTENT, result=None): return TrackedCheckpoint( dir_or_data=trial.trainable_name, storage_mode=TrackedCheckpoint.PERSISTENT, - result=result, + metrics=result, ) def reset_trial(self, trial, new_config, new_experiment_tag): @@ -850,7 +850,7 @@ def checkpoint(self): return TrackedCheckpoint( dir_or_data=self.trainable_name, storage_mode=TrackedCheckpoint.MEMORY, - result=None, + metrics=None, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index f96e98cfe803..542fef127e40 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -443,7 +443,7 @@ def checkpoint(self): return TrackedCheckpoint( dir_or_data="None", storage_mode=TrackedCheckpoint.MEMORY, - result={}, + metrics={}, ) @property diff --git a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py index 43fedd3ce095..a0dfcc843ff5 100644 --- a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py +++ b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py @@ -52,7 +52,7 @@ def checkpoint(self): return TrackedCheckpoint( dir_or_data="None", storage_mode=TrackedCheckpoint.MEMORY, - result={}, + metrics={}, ) From cdb9d491c8205d71a712c8adeba3825cc2e9b916 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:04:39 +0100 Subject: [PATCH 46/81] fix delete fn --- python/ray/tune/checkpoint_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 1c1ff068d9c3..89ba744a5fdc 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -138,9 +138,9 @@ def __getstate__(self): TrackedCheckpoint.MEMORY, None ) # Avoid serializing lambda since it may capture cyclical dependencies. - state.pop("delete") + state.pop("_delete_fn") return state def __setstate__(self, state): self.__dict__.update(state) - self.delete = None + self._delete_fn = None From 6ce138ba92dfef7c84674184eb481efcc7d693cf Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:41:35 +0100 Subject: [PATCH 47/81] [tune/train] Consolidate checkpoint manager 1: Common checkpoint manager class --- .../ray/util/ml_utils/checkpoint_manager.py | 80 ++++--------------- 1 file changed, 14 insertions(+), 66 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 2334d69ad7cd..f97068b58377 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,4 +1,3 @@ -import copy import gc import heapq import logging @@ -11,7 +10,6 @@ from typing import Optional, Dict, Union, Callable, Tuple, List, Any import ray -from ray.ml import Checkpoint from ray.tune.result import NODE_IP from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI @@ -31,21 +29,12 @@ class TrackedCheckpoint: order to add metadata (e.g. the result, or the node where it has been created) and for bookkeeping purposes. 
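As a hedged construction sketch (keyword names follow the signature as it stands after this patch; the result/metrics keyword is renamed back and forth elsewhere in this series, and the values below are illustrative only):

from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint

# An in-memory checkpoint: dir_or_data carries the raw dict, while the observed
# metrics travel with it so a manager can rank checkpoints later.
tracked = TrackedCheckpoint(
    dir_or_data={"epoch": 3, "weights": [0.1, 0.2]},
    storage_mode=TrackedCheckpoint.MEMORY,
    checkpoint_id=3,
    result={"loss": 0.42},
)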
- The data can be an object, a checkpoint directory, or a future to either. Because - we can't know if it's data or a directory from a future, this class expects - a ``storage_mode`` that makes the data type explicit. - - The passed metrics can be used to compare performance of different checkpoints. - The ``checkpoint_id`` is passed as an alternative to be able to order - checkpoints in time. - Args: dir_or_data: Checkpoint directory, checkpoint data, or a future to either. storage_mode: Either MEMORY or PERSISTENT. - checkpoint_id: Checkpoint number. Will be used to determine checkpoint order - if metrics are not available. Usually this should be monotonically + checkpoint_id: Checkpoint number. Usually this should be monotonically increasing for each tracked checkpoint. - metrics: Observed metrics for this checkpoint. This is used to determine + result: Observed metrics for this checkpoint. This is used to determine the value of the ``checkpoint_score_attr``. node_ip: IP of the node where the checkpoint was generated. Defaults to the current node. @@ -59,15 +48,16 @@ def __init__( dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], storage_mode: str, checkpoint_id: Optional[int] = None, - metrics: Optional[Dict] = None, + result: Optional[Dict] = None, node_ip: Optional[str] = None, ): self.dir_or_data = dir_or_data self.id = checkpoint_id self.storage_mode = storage_mode - self.metrics = metrics or {} - self.node_ip = node_ip or self.metrics.get(NODE_IP, None) + # Todo: What to do if result is a subset of dir_or_data (dict) + self.result = result or {} + self.node_ip = node_ip or self.result.get(NODE_IP, None) def commit(self, path: Optional[Path] = None) -> None: """Commit checkpoint to disk, if needed. @@ -75,20 +65,7 @@ def commit(self, path: Optional[Path] = None) -> None: Args: path: Path to commit checkpoint to. 
""" - if self.storage_mode == TrackedCheckpoint.MEMORY: - # Do not persist memory checkpoints - return - - if not isinstance(self.dir_or_data, dict): - # Only persist dictionaries - return - - if not path: - # If no path is given, skip - return - - checkpoint = Checkpoint.from_dict(self.dir_or_data) - self.dir_or_data = checkpoint.to_directory(str(path)) + pass def delete( self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None @@ -107,7 +84,7 @@ def delete( def __repr__(self): if self.storage_mode == TrackedCheckpoint.MEMORY: - return f"" + return f"" return ( f" Date: Fri, 13 May 2022 11:43:37 +0100 Subject: [PATCH 48/81] [tune/train] Consolidate checkpoint manager 2: Ray Train --- python/ray/train/checkpoint.py | 311 ++++++++++++++------------------- python/ray/train/trainer.py | 11 +- 2 files changed, 136 insertions(+), 186 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index dd03ed3197eb..ffbddb5bd4a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -1,27 +1,28 @@ -import heapq import logging -import numbers -import os -from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Dict, Union, Callable from ray import cloudpickle -from ray.train.constants import TIMESTAMP, TUNE_INSTALLED, TRAIN_CHECKPOINT_SUBDIR -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID +from ray.train.constants import ( + TIMESTAMP, + TRAIN_CHECKPOINT_SUBDIR, + TUNE_CHECKPOINT_FILE_NAME, + TUNE_CHECKPOINT_ID, + TUNE_INSTALLED, +) from ray.train.session import TrainingResult from ray.train.utils import construct_path -from ray.util import PublicAPI -from ray.util.ml_utils.util import is_nan +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointManager as CommonCheckpointManager, + TrackedCheckpoint, + CheckpointStrategy, +) if TUNE_INSTALLED: from ray import tune else: tune = None -MAX = "max" -MIN = "min" - logger = logging.getLogger(__name__) @@ -34,64 +35,58 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: return cloudpickle.load(f) -@PublicAPI(stability="beta") -@dataclass -class CheckpointStrategy: - """Configurable parameters for defining the Train checkpointing strategy. - - Default behavior is to persist all checkpoints to disk. If - ``num_to_keep`` is set, the default retention policy is to keep the - checkpoints with maximum timestamp, i.e. the most recent checkpoints. - - Args: - num_to_keep (Optional[int]): The number of checkpoints to keep - on disk for this run. If a checkpoint is persisted to disk after - there are already this many checkpoints, then an existing - checkpoint will be deleted. If this is ``None`` then checkpoints - will not be deleted. If this is ``0`` then no checkpoints will be - persisted to disk. - checkpoint_score_attribute (str): The attribute that will be used to - score checkpoints to determine which checkpoints should be kept - on disk when there are greater than ``num_to_keep`` checkpoints. - This attribute must be a key from the checkpoint - dictionary which has a numerical value. - checkpoint_score_order (str). Either "max" or "min". - If "max", then checkpoints with highest values of - ``checkpoint_score_attribute`` will be kept. - If "min", then checkpoints with lowest values of - ``checkpoint_score_attribute`` will be kept. +class _NotYetPersistedCheckpoint(TrackedCheckpoint): + """Tracked checkpoint that is not yet persisted to disk. 
+ + This checkpoint class supports lazy writing. The checkpoint manager will + only call ``commit()`` if the checkpoint should be kept on disk. This class + will only then write checkpoint data to disk. """ - num_to_keep: Optional[int] = None - checkpoint_score_attribute: str = TIMESTAMP - checkpoint_score_order: str = MAX + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def __post_init__(self): - if self.num_to_keep is not None and self.num_to_keep < 0: - raise ValueError( - f"Received invalidate num_to_keep: " - f"{self.num_to_keep}. " - f"Must be None or non-negative integer." - ) - if self.checkpoint_score_order not in (MAX, MIN): - raise ValueError( - f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' - ) + self._data_to_commit = self.dir_or_data + self.dir_or_data = None + @property + def committed(self) -> bool: + return not self._data_to_commit -class PersistedCheckpoint: - def __init__(self, path, priority): - self.path = path - self.priority = priority + def commit(self, path: Optional[Path] = None): + if self.committed: + return - def __lt__(self, other): - return self.priority < other.priority + assert path - def __repr__(self): - return f"PersistedCheckpoint({repr(self.path)})" + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(self._data_to_commit, f) + logger.debug(f"Checkpoint successfully written to: {path}") + self.dir_or_data = path + self._data_to_commit = None -class CheckpointManager: + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): + if not self.committed: + return + return super().delete(delete_fn=delete_fn) + + @classmethod + def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): + new_checkpoint = cls( + dir_or_data=checkpoint.dir_or_data, + storage_mode=TrackedCheckpoint.PERSISTENT, + checkpoint_id=checkpoint.id, + result=checkpoint.result, + node_ip=checkpoint.node_ip, + ) + return new_checkpoint + + +class CheckpointManager(CommonCheckpointManager): """Manages checkpoint processing, writing, and loading. @@ -119,55 +114,16 @@ class CheckpointManager: checkpoint may not be saved to disk. """ - def on_init(self, **kwargs): - """Checkpoint code executed during BackendExecutor init.""" - self.latest_checkpoint = None - - # Incremental unique checkpoint ID of this run. - self._latest_checkpoint_id = 0 - - # Used for keeping top K checkpoints. - self._top_persisted_checkpoints = [] - - # Best checkpoint altogether. - # Used for exposing best_checkpoint_path. - self._best_persisted_checkpoint = None - - def on_start_training( - self, - checkpoint_strategy: Optional[CheckpointStrategy], - run_dir: Path, - latest_checkpoint_id: Optional[int] = None, - ): - """Checkpoint code executed during BackendExecutor start_training.""" - # Restart checkpointing. - self._latest_checkpoint_id = latest_checkpoint_id if latest_checkpoint_id else 0 - self._checkpoint_strategy = ( - CheckpointStrategy() if checkpoint_strategy is None else checkpoint_strategy - ) + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir - def _process_checkpoint( - self, - checkpoint_results: List[TrainingResult], - decode_checkpoint_fn: Callable, - ) -> None: - """Perform all processing for a checkpoint.""" - - # Get checkpoint from first worker. 
- checkpoint = checkpoint_results[0].data + super().__init__(checkpoint_strategy=checkpoint_strategy) - # Decode checkpoint. - checkpoint = decode_checkpoint_fn(checkpoint) + self._validate_checkpoint_strategy() - # Store checkpoint in memory. - self.latest_checkpoint = checkpoint - - # Write checkpoint to disk. - self.write_checkpoint(checkpoint) - - # Increment checkpoint id. - self._latest_checkpoint_id += 1 + def _validate_checkpoint_strategy(self): + if self._checkpoint_strategy.checkpoint_score_attribute is None: + self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP def _load_checkpoint( self, checkpoint_to_load: Optional[Union[Dict, str, Path]] @@ -181,94 +137,77 @@ def _load_checkpoint( # Load checkpoint from path. return load_checkpoint_from_path(checkpoint_to_load) - def write_checkpoint(self, checkpoint: Dict): - """Writes checkpoint to disk.""" - num_to_keep = self._checkpoint_strategy.num_to_keep + def _process_checkpoint( + self, + checkpoint_results: List[TrainingResult], + decode_checkpoint_fn: Callable, + ) -> None: + """Ray Train entrypoint. Perform all processing for a checkpoint.""" + # Get checkpoint from first worker. + checkpoint_data = checkpoint_results[0].data - if num_to_keep == 0: - # Checkpoints should not be persisted to disk. - return + # Decode checkpoint. + checkpoint_data = decode_checkpoint_fn(checkpoint_data) - checkpoint_score_attribute = ( - self._checkpoint_strategy.checkpoint_score_attribute - ) - checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order - if checkpoint_score_attribute not in checkpoint: + score_attr = self._checkpoint_strategy.checkpoint_score_attribute + if ( + self._checkpoint_strategy.num_to_keep != 0 + and score_attr not in checkpoint_data + ): raise ValueError( f"Unable to persist checkpoint for " f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute}. " + f"{score_attr}. " f"Include this attribute in the call to " f"train.save_checkpoint." ) - checkpoint_score = checkpoint[checkpoint_score_attribute] - if not isinstance(checkpoint_score, numbers.Number): - raise ValueError( - f"Unable to persist checkpoint for " - f"checkpoint_score_attribute: " - f"{checkpoint_score_attribute} with value " - f"{checkpoint_score}. " - f"This attribute must be numerical." 
+ tracked_checkpoint = TrackedCheckpoint( + dir_or_data=checkpoint_data, + checkpoint_id=self._latest_checkpoint_id, + storage_mode=TrackedCheckpoint.MEMORY, + result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + ) + self.register_checkpoint(checkpoint=tracked_checkpoint) + + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + # Always update the latest memory checkpoint + self._replace_latest_memory_checkpoint(checkpoint) + + # Only process further if we consider keeping this checkpoint on disk + if self._checkpoint_strategy.num_to_keep != 0: + not_yet_persisted_checkpoint = ( + _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) ) + self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - def priority(checkpoint_score_order, checkpoint_score): - # Treat NaN as worst - # The tuple structure is (not is_nan(), metric), which makes - # the nan values to be always considered as the worst - # metrics by the heap - if checkpoint_score_order != MAX: - checkpoint_score = -checkpoint_score - return (not is_nan(checkpoint_score), checkpoint_score) + self._latest_checkpoint_id += 1 + + def _get_next_checkpoint_path(self) -> Optional[Path]: + """Path to the next checkpoint to persist.""" + checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + return self.latest_checkpoint_dir.joinpath(checkpoint_file) - checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) + def on_start_training( + self, + checkpoint_strategy: Optional[CheckpointStrategy], + run_dir: str, + latest_checkpoint_id: Optional[int] = 0, + ): + checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + self._checkpoint_strategy = checkpoint_strategy - persisted_checkpoint = PersistedCheckpoint( - self.next_checkpoint_path, checkpoint_priority - ) + self._validate_checkpoint_strategy() - def write_to_disk(path: Path): - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(checkpoint, f) - logger.debug(f"Checkpoint successfully written to: " f"{path}") - - def remove_from_disk(path: Path): - os.remove(path) - - if num_to_keep is None: - # Keep all checkpoints. - write_to_disk(self.next_checkpoint_path) - elif len(self._top_persisted_checkpoints) < num_to_keep: - # Keep first num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - heapq.heappush(self._top_persisted_checkpoints, persisted_checkpoint) - elif ( - persisted_checkpoint.priority > self._top_persisted_checkpoints[0].priority - ): - # Keep top num_to_keep checkpoints. - write_to_disk(self.next_checkpoint_path) - worst_checkpoint = heapq.heappushpop( - self._top_persisted_checkpoints, persisted_checkpoint - ) - worst_checkpoint_path = worst_checkpoint.path - remove_from_disk(worst_checkpoint_path) - logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint_path}.") - else: - # If the latest checkpoint has the same or lower priority, skip it. - logger.debug( - f"Skipping checkpoint due to low score:" f"{self.next_checkpoint_path}." - ) + self.run_dir = run_dir + self._latest_checkpoint_id = latest_checkpoint_id or 0 - # Update single best checkpoint. - if ( - self._best_persisted_checkpoint is None - or persisted_checkpoint.priority > self._best_persisted_checkpoint.priority - ): - # If the latest checkpoint has the same or lower priority, skip it. 
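Stripped of file I/O, the retention logic being removed here reduces to a bounded min-heap; the standalone sketch below mirrors that pattern with plain scores instead of PersistedCheckpoint objects.

import heapq

def keep_top_k(scores, k):
    """Keep the k best scores, mimicking the heap-based retention above."""
    heap = []  # min-heap: heap[0] is always the worst score kept so far
    for score in scores:
        if len(heap) < k:
            heapq.heappush(heap, score)
        elif score > heap[0]:
            # Evict the current worst entry; in the manager this is the point
            # where the evicted checkpoint would also be removed from disk.
            heapq.heappushpop(heap, score)
        # else: not better than the worst kept checkpoint, so it is skipped
    return sorted(heap, reverse=True)

assert keep_top_k([0.1, 0.9, 0.4, 0.7, 0.2], k=2) == [0.9, 0.7]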
- self._best_persisted_checkpoint = persisted_checkpoint + # Train-specific attributes + @property + def latest_checkpoint(self): + if not self._latest_memory_checkpoint: + return None + return self._latest_memory_checkpoint.dir_or_data @property def latest_checkpoint_dir(self) -> Optional[Path]: @@ -294,7 +233,7 @@ def next_checkpoint_path(self) -> Optional[Path]: def best_checkpoint_path(self) -> Optional[Path]: """Path to the best persisted checkpoint.""" if self._best_persisted_checkpoint: - return self._best_persisted_checkpoint.path + return Path(self._best_persisted_checkpoint.dir_or_data) else: return None @@ -328,16 +267,22 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) + def _decide_what_to_do_with_checkpoint( + self, checkpoint: _NotYetPersistedCheckpoint + ): + assert isinstance(checkpoint, _NotYetPersistedCheckpoint) + assert not checkpoint.committed + + self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) # Use a standard file name so that we know which file to load # the checkpoint from. file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - with file_path.open("wb") as f: - cloudpickle.dump(checkpoint, f) + checkpoint.commit(file_path) + + return super()._decide_what_to_do_with_checkpoint(checkpoint) def construct_checkpoint_file_name(checkpoint_id: int) -> str: diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 798dd99454a0..efb39e9c3535 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,11 +224,16 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) + # Todo (krfricke): Initialize checkpoint manager here with final values + # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager() + self.checkpoint_manager = TuneCheckpointManager( + checkpoint_strategy=None, run_dir=None + ) else: - self.checkpoint_manager = CheckpointManager() - self.checkpoint_manager.on_init() + self.checkpoint_manager = CheckpointManager( + checkpoint_strategy=None, run_dir=None + ) def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From bba91f327242cf38a804cc2f300c57aafe5599a6 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:40:35 +0100 Subject: [PATCH 49/81] Adjust to changes in base PR --- python/ray/train/checkpoint.py | 81 +++----------------------- python/ray/train/tests/test_trainer.py | 1 + 2 files changed, 8 insertions(+), 74 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index ffbddb5bd4a8..1bc2349ff5b7 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional, Dict, Union, Callable -from ray import cloudpickle +from ray.ml import Checkpoint from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, @@ -31,59 +31,8 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: checkpoint_path = Path(checkpoint_to_load).expanduser() if not checkpoint_path.exists(): raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") - with 
checkpoint_path.open("rb") as f: - return cloudpickle.load(f) - - -class _NotYetPersistedCheckpoint(TrackedCheckpoint): - """Tracked checkpoint that is not yet persisted to disk. - - This checkpoint class supports lazy writing. The checkpoint manager will - only call ``commit()`` if the checkpoint should be kept on disk. This class - will only then write checkpoint data to disk. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._data_to_commit = self.dir_or_data - self.dir_or_data = None - - @property - def committed(self) -> bool: - return not self._data_to_commit - - def commit(self, path: Optional[Path] = None): - if self.committed: - return - - assert path - - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(self._data_to_commit, f) - logger.debug(f"Checkpoint successfully written to: {path}") - - self.dir_or_data = path - self._data_to_commit = None - - def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): - if not self.committed: - return - return super().delete(delete_fn=delete_fn) - - @classmethod - def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): - new_checkpoint = cls( - dir_or_data=checkpoint.dir_or_data, - storage_mode=TrackedCheckpoint.PERSISTENT, - checkpoint_id=checkpoint.id, - result=checkpoint.result, - node_ip=checkpoint.node_ip, - ) - return new_checkpoint + checkpoint = Checkpoint.from_directory(str(checkpoint_path)) + return checkpoint.to_dict() class CheckpointManager(CommonCheckpointManager): @@ -114,6 +63,8 @@ class CheckpointManager(CommonCheckpointManager): checkpoint may not be saved to disk. """ + _persist_memory_checkpoints = True + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir @@ -166,23 +117,10 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=TrackedCheckpoint.MEMORY, - result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) - def register_checkpoint(self, checkpoint: TrackedCheckpoint): - # Always update the latest memory checkpoint - self._replace_latest_memory_checkpoint(checkpoint) - - # Only process further if we consider keeping this checkpoint on disk - if self._checkpoint_strategy.num_to_keep != 0: - not_yet_persisted_checkpoint = ( - _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) - ) - self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - - self._latest_checkpoint_id += 1 - def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) @@ -267,12 +205,7 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint( - self, checkpoint: _NotYetPersistedCheckpoint - ): - assert isinstance(checkpoint, _NotYetPersistedCheckpoint) - assert not checkpoint.committed - + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. 
with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index dc72e21a939f..cc017de709b9 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -15,6 +15,7 @@ from ray.train.constants import TRAIN_ENABLE_WORKER_SPREAD_ENV from ray.train.torch import TorchConfig from ray.train.tensorflow import TensorflowConfig + from ray.train.horovod import HorovodConfig from ray.train.callbacks.callback import TrainingCallback from ray.train.worker_group import WorkerGroup From 28d00eb5eaf24c6da67df109e6cd0b981f792ac9 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 14:59:48 +0100 Subject: [PATCH 50/81] Fix faulty rebase --- python/ray/train/trainer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index efb39e9c3535..798dd99454a0 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,16 +224,11 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) - # Todo (krfricke): Initialize checkpoint manager here with final values - # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager( - checkpoint_strategy=None, run_dir=None - ) + self.checkpoint_manager = TuneCheckpointManager() else: - self.checkpoint_manager = CheckpointManager( - checkpoint_strategy=None, run_dir=None - ) + self.checkpoint_manager = CheckpointManager() + self.checkpoint_manager.on_init() def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From fae9a54307e3c882d6b35f65ec13402e0fd5419d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 13 May 2022 11:43:37 +0100 Subject: [PATCH 51/81] [tune/train] Consolidate checkpoint manager 2: Ray Train --- python/ray/train/checkpoint.py | 81 +++++++++++++++++++++++++++++++--- python/ray/train/trainer.py | 11 +++-- 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 1bc2349ff5b7..ffbddb5bd4a8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional, Dict, Union, Callable -from ray.ml import Checkpoint +from ray import cloudpickle from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, @@ -31,8 +31,59 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: checkpoint_path = Path(checkpoint_to_load).expanduser() if not checkpoint_path.exists(): raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") - checkpoint = Checkpoint.from_directory(str(checkpoint_path)) - return checkpoint.to_dict() + with checkpoint_path.open("rb") as f: + return cloudpickle.load(f) + + +class _NotYetPersistedCheckpoint(TrackedCheckpoint): + """Tracked checkpoint that is not yet persisted to disk. + + This checkpoint class supports lazy writing. The checkpoint manager will + only call ``commit()`` if the checkpoint should be kept on disk. This class + will only then write checkpoint data to disk. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._data_to_commit = self.dir_or_data + self.dir_or_data = None + + @property + def committed(self) -> bool: + return not self._data_to_commit + + def commit(self, path: Optional[Path] = None): + if self.committed: + return + + assert path + + # Get or create checkpoint dir. + path.parent.mkdir(parents=True, exist_ok=True) + # Write checkpoint to disk. + with path.open("wb") as f: + cloudpickle.dump(self._data_to_commit, f) + logger.debug(f"Checkpoint successfully written to: {path}") + + self.dir_or_data = path + self._data_to_commit = None + + def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): + if not self.committed: + return + return super().delete(delete_fn=delete_fn) + + @classmethod + def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): + new_checkpoint = cls( + dir_or_data=checkpoint.dir_or_data, + storage_mode=TrackedCheckpoint.PERSISTENT, + checkpoint_id=checkpoint.id, + result=checkpoint.result, + node_ip=checkpoint.node_ip, + ) + return new_checkpoint class CheckpointManager(CommonCheckpointManager): @@ -63,8 +114,6 @@ class CheckpointManager(CommonCheckpointManager): checkpoint may not be saved to disk. """ - _persist_memory_checkpoints = True - def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir @@ -117,10 +166,23 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=TrackedCheckpoint.MEMORY, - metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, + result={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) + def register_checkpoint(self, checkpoint: TrackedCheckpoint): + # Always update the latest memory checkpoint + self._replace_latest_memory_checkpoint(checkpoint) + + # Only process further if we consider keeping this checkpoint on disk + if self._checkpoint_strategy.num_to_keep != 0: + not_yet_persisted_checkpoint = ( + _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) + ) + self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) + + self._latest_checkpoint_id += 1 + def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) @@ -205,7 +267,12 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): + def _decide_what_to_do_with_checkpoint( + self, checkpoint: _NotYetPersistedCheckpoint + ): + assert isinstance(checkpoint, _NotYetPersistedCheckpoint) + assert not checkpoint.committed + self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. 
with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 798dd99454a0..efb39e9c3535 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -224,11 +224,16 @@ def __init__( self._backend_executor = ActorWrapper(backend_executor_actor) + # Todo (krfricke): Initialize checkpoint manager here with final values + # rather than in `on_training_start` if self._is_tune_enabled(): - self.checkpoint_manager = TuneCheckpointManager() + self.checkpoint_manager = TuneCheckpointManager( + checkpoint_strategy=None, run_dir=None + ) else: - self.checkpoint_manager = CheckpointManager() - self.checkpoint_manager.on_init() + self.checkpoint_manager = CheckpointManager( + checkpoint_strategy=None, run_dir=None + ) def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path: """Create logdir for the Trainer.""" From 35f7b081ec2be919a675e0e64dc6c77d1847af3a Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:11:54 +0100 Subject: [PATCH 52/81] Undo changes --- python/ray/tune/analysis/experiment_analysis.py | 2 +- python/ray/tune/checkpoint_manager.py | 2 +- python/ray/tune/trial.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index ed0255365040..5d1bd3621f1e 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -423,7 +423,7 @@ def get_trial_checkpoints_paths( # Support metrics given as paths, e.g. # "info/learner/default_policy/policy_loss". return [ - (c.value, unflattened_lookup(metric, c.metrics)) for c in checkpoints + (c.value, unflattened_lookup(metric, c.result)) for c in checkpoints ] else: raise ValueError("trial should be a string or a Trial instance.") diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 321f1a132d39..75cf4b8cb835 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -214,7 +214,7 @@ def best_checkpoints(self): return [queue_item.value for queue_item in checkpoints] def _priority(self, checkpoint): - result = flatten_dict(checkpoint.metrics) + result = flatten_dict(checkpoint.result) priority = result[self._checkpoint_score_attr] if self._checkpoint_score_desc: priority = -priority diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 466bbc606d0d..f793714262f2 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -660,7 +660,7 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): def on_restore(self): """Handles restoration completion.""" assert self.is_restoring - self.last_result = self.restoring_from.metrics + self.last_result = self.restoring_from.result self.restoring_from = None self.invalidate_json_state() From dd7f796da33a131945d6678fe7a02c96697a01c1 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:18:00 +0100 Subject: [PATCH 53/81] Restore common entrypoint --- .../ray/util/ml_utils/checkpoint_manager.py | 80 +++++++++++++++---- 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index f97068b58377..2334d69ad7cd 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,3 +1,4 @@ +import copy import gc import heapq import logging @@ -10,6 
+11,7 @@ from typing import Optional, Dict, Union, Callable, Tuple, List, Any import ray +from ray.ml import Checkpoint from ray.tune.result import NODE_IP from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI @@ -29,12 +31,21 @@ class TrackedCheckpoint: order to add metadata (e.g. the result, or the node where it has been created) and for bookkeeping purposes. + The data can be an object, a checkpoint directory, or a future to either. Because + we can't know if it's data or a directory from a future, this class expects + a ``storage_mode`` that makes the data type explicit. + + The passed metrics can be used to compare performance of different checkpoints. + The ``checkpoint_id`` is passed as an alternative to be able to order + checkpoints in time. + Args: dir_or_data: Checkpoint directory, checkpoint data, or a future to either. storage_mode: Either MEMORY or PERSISTENT. - checkpoint_id: Checkpoint number. Usually this should be monotonically + checkpoint_id: Checkpoint number. Will be used to determine checkpoint order + if metrics are not available. Usually this should be monotonically increasing for each tracked checkpoint. - result: Observed metrics for this checkpoint. This is used to determine + metrics: Observed metrics for this checkpoint. This is used to determine the value of the ``checkpoint_score_attr``. node_ip: IP of the node where the checkpoint was generated. Defaults to the current node. @@ -48,16 +59,15 @@ def __init__( dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], storage_mode: str, checkpoint_id: Optional[int] = None, - result: Optional[Dict] = None, + metrics: Optional[Dict] = None, node_ip: Optional[str] = None, ): self.dir_or_data = dir_or_data self.id = checkpoint_id self.storage_mode = storage_mode - # Todo: What to do if result is a subset of dir_or_data (dict) - self.result = result or {} - self.node_ip = node_ip or self.result.get(NODE_IP, None) + self.metrics = metrics or {} + self.node_ip = node_ip or self.metrics.get(NODE_IP, None) def commit(self, path: Optional[Path] = None) -> None: """Commit checkpoint to disk, if needed. @@ -65,7 +75,20 @@ def commit(self, path: Optional[Path] = None) -> None: Args: path: Path to commit checkpoint to. 
""" - pass + if self.storage_mode == TrackedCheckpoint.MEMORY: + # Do not persist memory checkpoints + return + + if not isinstance(self.dir_or_data, dict): + # Only persist dictionaries + return + + if not path: + # If no path is given, skip + return + + checkpoint = Checkpoint.from_dict(self.dir_or_data) + self.dir_or_data = checkpoint.to_directory(str(path)) def delete( self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None @@ -84,7 +107,7 @@ def delete( def __repr__(self): if self.storage_mode == TrackedCheckpoint.MEMORY: - return f"" + return f"" return ( f" Date: Tue, 17 May 2022 14:40:35 +0100 Subject: [PATCH 54/81] Adjust to changes in base PR --- python/ray/train/checkpoint.py | 81 +++------------------------------- 1 file changed, 7 insertions(+), 74 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index ffbddb5bd4a8..1bc2349ff5b7 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional, Dict, Union, Callable -from ray import cloudpickle +from ray.ml import Checkpoint from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, @@ -31,59 +31,8 @@ def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict: checkpoint_path = Path(checkpoint_to_load).expanduser() if not checkpoint_path.exists(): raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") - with checkpoint_path.open("rb") as f: - return cloudpickle.load(f) - - -class _NotYetPersistedCheckpoint(TrackedCheckpoint): - """Tracked checkpoint that is not yet persisted to disk. - - This checkpoint class supports lazy writing. The checkpoint manager will - only call ``commit()`` if the checkpoint should be kept on disk. This class - will only then write checkpoint data to disk. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._data_to_commit = self.dir_or_data - self.dir_or_data = None - - @property - def committed(self) -> bool: - return not self._data_to_commit - - def commit(self, path: Optional[Path] = None): - if self.committed: - return - - assert path - - # Get or create checkpoint dir. - path.parent.mkdir(parents=True, exist_ok=True) - # Write checkpoint to disk. - with path.open("wb") as f: - cloudpickle.dump(self._data_to_commit, f) - logger.debug(f"Checkpoint successfully written to: {path}") - - self.dir_or_data = path - self._data_to_commit = None - - def delete(self, delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None): - if not self.committed: - return - return super().delete(delete_fn=delete_fn) - - @classmethod - def from_tracked_checkpoint(cls, checkpoint: TrackedCheckpoint): - new_checkpoint = cls( - dir_or_data=checkpoint.dir_or_data, - storage_mode=TrackedCheckpoint.PERSISTENT, - checkpoint_id=checkpoint.id, - result=checkpoint.result, - node_ip=checkpoint.node_ip, - ) - return new_checkpoint + checkpoint = Checkpoint.from_directory(str(checkpoint_path)) + return checkpoint.to_dict() class CheckpointManager(CommonCheckpointManager): @@ -114,6 +63,8 @@ class CheckpointManager(CommonCheckpointManager): checkpoint may not be saved to disk. 
""" + _persist_memory_checkpoints = True + def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): self.run_dir = run_dir @@ -166,23 +117,10 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=TrackedCheckpoint.MEMORY, - result={score_attr: checkpoint_data.get(score_attr, 0.0)}, + metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) - def register_checkpoint(self, checkpoint: TrackedCheckpoint): - # Always update the latest memory checkpoint - self._replace_latest_memory_checkpoint(checkpoint) - - # Only process further if we consider keeping this checkpoint on disk - if self._checkpoint_strategy.num_to_keep != 0: - not_yet_persisted_checkpoint = ( - _NotYetPersistedCheckpoint.from_tracked_checkpoint(checkpoint) - ) - self._decide_what_to_do_with_checkpoint(not_yet_persisted_checkpoint) - - self._latest_checkpoint_id += 1 - def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) @@ -267,12 +205,7 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint( - self, checkpoint: _NotYetPersistedCheckpoint - ): - assert isinstance(checkpoint, _NotYetPersistedCheckpoint) - assert not checkpoint.committed - + def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): self.add_tune_checkpoint_id(checkpoint._data_to_commit) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: From 918c2141adab0af9a5bb80667a7d9d4cc1391aa0 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:21:23 +0100 Subject: [PATCH 55/81] Do not persist memory checkpoints --- python/ray/tune/checkpoint_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 1e40a3a38b90..521eb8edb0d4 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -29,6 +29,7 @@ class CheckpointManager(CommonCheckpointManager): delete_fn: Function that deletes checkpoints. Must be idempotent. """ + _persist_memory_checkpoints = False def __init__( self, From 7cf2645ccdd2366b8a57dd6b8e4513028e7f6869 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:27:35 +0100 Subject: [PATCH 56/81] Use enum --- .../ray/util/ml_utils/checkpoint_manager.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 2334d69ad7cd..710fa0b724d2 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -1,4 +1,5 @@ import copy +import enum import gc import heapq import logging @@ -23,6 +24,11 @@ logger = logging.getLogger(__name__) +class CheckpointStorage(enum.Enum): + MEMORY = enum.auto() + PERSISTENT = enum.auto() + + @DeveloperAPI class TrackedCheckpoint: """Checkpoint tracked by a checkpoint manager. @@ -51,13 +57,10 @@ class TrackedCheckpoint: to the current node. 
""" - MEMORY = "memory" - PERSISTENT = "persistent" - def __init__( self, dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], - storage_mode: str, + storage_mode: CheckpointStorage, checkpoint_id: Optional[int] = None, metrics: Optional[Dict] = None, node_ip: Optional[str] = None, @@ -75,7 +78,7 @@ def commit(self, path: Optional[Path] = None) -> None: Args: path: Path to commit checkpoint to. """ - if self.storage_mode == TrackedCheckpoint.MEMORY: + if self.storage_mode == CheckpointStorage.MEMORY: # Do not persist memory checkpoints return @@ -106,7 +109,7 @@ def delete( logger.warning(f"Checkpoint deletion failed: {e}") def __repr__(self): - if self.storage_mode == TrackedCheckpoint.MEMORY: + if self.storage_mode == CheckpointStorage.MEMORY: return f"" return ( @@ -116,7 +119,7 @@ def __repr__(self): def _default_delete_fn(checkpoint: TrackedCheckpoint): - if checkpoint.storage_mode != TrackedCheckpoint.PERSISTENT: + if checkpoint.storage_mode != CheckpointStorage.PERSISTENT: return if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): @@ -262,12 +265,12 @@ def register_checkpoint(self, checkpoint: TrackedCheckpoint): """ checkpoint.id = checkpoint.id or self._latest_checkpoint_id - if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: + if checkpoint.storage_mode == CheckpointStorage.MEMORY: self._replace_latest_memory_checkpoint(checkpoint) if self._persist_memory_checkpoints: persisted_checkpoint = copy.copy(checkpoint) - persisted_checkpoint.storage_mode = TrackedCheckpoint.PERSISTENT + persisted_checkpoint.storage_mode = CheckpointStorage.PERSISTENT else: persisted_checkpoint = None else: @@ -279,7 +282,7 @@ def register_checkpoint(self, checkpoint: TrackedCheckpoint): self._latest_checkpoint_id += 1 def _replace_latest_memory_checkpoint(self, memory_checkpoint: TrackedCheckpoint): - assert memory_checkpoint.storage_mode == TrackedCheckpoint.MEMORY + assert memory_checkpoint.storage_mode == CheckpointStorage.MEMORY self._latest_memory_checkpoint = memory_checkpoint # Avoid memory leaks on k8s pods gc.collect() @@ -346,7 +349,7 @@ def _get_checkpoint_score( ) def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): - assert checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT + assert checkpoint.storage_mode == CheckpointStorage.PERSISTENT checkpoint_score = self._get_checkpoint_score(checkpoint) wrapped_checkpoint = _HeapCheckpointWrapper( @@ -421,7 +424,7 @@ def __getstate__(self): state["_newest_memory_checkpoint"] = TrackedCheckpoint( dir_or_data=None, checkpoint_id=0, - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, ) return state From 88f9fed53ec8f74e2ea4594c477c472d8862a8fa Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:28:30 +0100 Subject: [PATCH 57/81] Use enum --- python/ray/train/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 1bc2349ff5b7..a65d72269b14 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -15,7 +15,7 @@ from ray.util.ml_utils.checkpoint_manager import ( CheckpointManager as CommonCheckpointManager, TrackedCheckpoint, - CheckpointStrategy, + CheckpointStrategy, CheckpointStorage, ) if TUNE_INSTALLED: @@ -116,7 +116,7 @@ def _process_checkpoint( tracked_checkpoint = TrackedCheckpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, - storage_mode=TrackedCheckpoint.MEMORY, + 
storage_mode=CheckpointStorage.MEMORY, metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) From 59264efd31408ec1de2cd5524c550a6c1e400c91 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:31:55 +0100 Subject: [PATCH 58/81] Use enum --- python/ray/train/checkpoint.py | 3 +- python/ray/tune/checkpoint_manager.py | 14 ++++---- python/ray/tune/ray_trial_executor.py | 8 ++--- python/ray/tune/schedulers/pbt.py | 6 ++-- python/ray/tune/syncer.py | 4 +-- .../ray/tune/tests/test_checkpoint_manager.py | 36 ++++++++++--------- .../ray/tune/tests/test_ray_trial_executor.py | 6 ++-- python/ray/tune/tests/test_trial_runner_2.py | 4 +-- .../tune/tests/test_trial_runner_callbacks.py | 4 +-- python/ray/tune/tests/test_trial_scheduler.py | 10 +++--- .../tune/tests/test_trial_scheduler_pbt.py | 4 +-- .../test_trial_scheduler_resource_changing.py | 4 +-- python/ray/tune/trial.py | 6 ++-- python/ray/tune/trial_executor.py | 6 ++-- python/ray/tune/trial_runner.py | 6 ++-- 15 files changed, 64 insertions(+), 57 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index a65d72269b14..9fada6205acf 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -15,7 +15,8 @@ from ray.util.ml_utils.checkpoint_manager import ( CheckpointManager as CommonCheckpointManager, TrackedCheckpoint, - CheckpointStrategy, CheckpointStorage, + CheckpointStrategy, + CheckpointStorage, ) if TUNE_INSTALLED: diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 521eb8edb0d4..b7a6cee89698 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -9,6 +9,7 @@ MAX, CheckpointManager as CommonCheckpointManager, TrackedCheckpoint, + CheckpointStorage, ) from ray.util.ml_utils.dict import flatten_dict from ray.util.ml_utils.util import is_nan @@ -29,6 +30,7 @@ class CheckpointManager(CommonCheckpointManager): delete_fn: Function that deletes checkpoints. Must be idempotent. """ + _persist_memory_checkpoints = False def __init__( @@ -64,10 +66,10 @@ def handle_checkpoint(self, checkpoint: TrackedCheckpoint): checkpoint.id = checkpoint.id or self._latest_checkpoint_id self._latest_checkpoint_id += 1 - if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: + if checkpoint.storage_mode == CheckpointStorage.MEMORY: self._replace_latest_memory_checkpoint(checkpoint) else: - assert checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT + assert checkpoint.storage_mode == CheckpointStorage.PERSISTENT assert ( self._checkpoint_strategy.num_to_keep is None or self._checkpoint_strategy.num_to_keep > 0 @@ -80,7 +82,7 @@ def on_checkpoint(self, checkpoint: TrackedCheckpoint): self.handle_checkpoint(checkpoint) def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): - assert persisted_checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT + assert persisted_checkpoint.storage_mode == CheckpointStorage.PERSISTENT super()._skip_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) # Ray Tune always keeps track of the latest persisted checkpoint. 
# Note that this checkpoint will be deleted once it is not the @@ -96,7 +98,7 @@ def newest_persistent_checkpoint(self): return self._latest_persisted_checkpoint or TrackedCheckpoint( dir_or_data=None, checkpoint_id=-1, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, ) @property @@ -113,7 +115,7 @@ def newest_memory_checkpoint(self): return self._latest_memory_checkpoint or TrackedCheckpoint( dir_or_data=None, checkpoint_id=-1, - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, ) def best_checkpoints(self): @@ -136,7 +138,7 @@ def __getstate__(self): state = self.__dict__.copy() # Avoid serializing the memory checkpoint. state["_newest_memory_checkpoint"] = TrackedCheckpoint( - TrackedCheckpoint.MEMORY, None + CheckpointStorage.MEMORY, None ) # Avoid serializing lambda since it may capture cyclical dependencies. state.pop("_delete_fn") diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index e215d8233a7c..14519e6762ea 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -38,7 +38,7 @@ from ray.tune.utils.resource_updater import ResourceUpdater from ray.util import log_once from ray.util.annotations import DeveloperAPI -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage from ray.util.placement_group import remove_placement_group, PlacementGroup logger = logging.getLogger(__name__) @@ -672,7 +672,7 @@ def force_reconcilation_on_next_step_end(self) -> None: def save( self, trial: Trial, - storage: str = TrackedCheckpoint.PERSISTENT, + storage: str = CheckpointStorage.PERSISTENT, result: Optional[Dict] = None, ) -> TrackedCheckpoint: """Saves the trial's state to a checkpoint asynchronously. @@ -690,7 +690,7 @@ def save( logger.debug(f"saving trial {trial}") result = result or trial.last_result with self._change_working_directory(trial): - if storage == TrackedCheckpoint.MEMORY: + if storage == CheckpointStorage.MEMORY: value = trial.runner.save_to_object.remote() checkpoint = TrackedCheckpoint( dir_or_data=value, storage_mode=storage, metrics=result @@ -725,7 +725,7 @@ def restore(self, trial: Trial) -> None: ) value = checkpoint.dir_or_data node_ip = checkpoint.node_ip - if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: + if checkpoint.storage_mode == CheckpointStorage.MEMORY: logger.debug("Trial %s: Attempting restore from object", trial) # Note that we don't store the remote since in-memory checkpoints # don't guarantee fault tolerance and don't need to be waited on. 
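The call-site pattern that the "Use enum" patches converge on is to construct a TrackedCheckpoint with an explicit CheckpointStorage member and hand it to the manager, which then applies the configured CheckpointStrategy. A minimal sketch of that flow, using only names defined in python/ray/util/ml_utils/checkpoint_manager.py at this point in the series (the payload dict and the metric value are illustrative placeholders, not taken from the patches):

from ray.util.ml_utils.checkpoint_manager import (
    CheckpointManager,
    CheckpointStorage,
    CheckpointStrategy,
    TrackedCheckpoint,
)

# Keep at most two persisted checkpoints for this run.
manager = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=2))

# The storage mode is now an enum member instead of the former class-level
# string constants on TrackedCheckpoint.
manager.register_checkpoint(
    TrackedCheckpoint(
        dir_or_data={"weights": [0.0, 1.0]},
        storage_mode=CheckpointStorage.PERSISTENT,
        metrics={"training_iteration": 1},
    )
)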
diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index ba1cda338296..2e199971559c 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -18,7 +18,7 @@ from ray.tune.suggest.variant_generator import format_vars from ray.tune.trial import Trial from ray.util.debug import log_once -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import CheckpointStorage logger = logging.getLogger(__name__) @@ -529,7 +529,7 @@ def _checkpoint_or_exploit( state.last_checkpoint = trial.checkpoint else: state.last_checkpoint = trial_executor.save( - trial, TrackedCheckpoint.MEMORY, result=state.last_result + trial, CheckpointStorage.MEMORY, result=state.last_result ) self._num_checkpoints += 1 else: @@ -873,7 +873,7 @@ def on_trial_result( ) checkpoint = trial_runner.trial_executor.save( - trial, TrackedCheckpoint.MEMORY, result=result + trial, CheckpointStorage.MEMORY, result=result ) new_tag = make_experiment_tag(self.experiment_tag, new_config, new_config) diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index e97116036b4f..e136d72fdd88 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -37,7 +37,7 @@ RemoteTaskClient, ) from ray.util.annotations import PublicAPI -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage if TYPE_CHECKING: from ray.tune.trial import Trial @@ -526,7 +526,7 @@ def _remove_trial_syncer(self, trial: "Trial"): self._syncers.pop(trial, None) def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: TrackedCheckpoint): - if checkpoint.storage_mode == TrackedCheckpoint.MEMORY: + if checkpoint.storage_mode == CheckpointStorage.MEMORY: return trial_syncer = self._get_trial_syncer(trial) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 14679640aa27..81196f344ce6 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -8,7 +8,11 @@ from ray.tune.result import TRAINING_ITERATION from ray.tune.checkpoint_manager import CheckpointManager -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, logger +from ray.util.ml_utils.checkpoint_manager import ( + TrackedCheckpoint, + logger, + CheckpointStorage, +) class CheckpointManagerTest(unittest.TestCase): @@ -27,13 +31,13 @@ def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) memory_checkpoint = TrackedCheckpoint( dir_or_data={0}, - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, metrics=self.mock_result(0, 0), ) checkpoint_manager.on_checkpoint(memory_checkpoint) persistent_checkpoint = TrackedCheckpoint( dir_or_data={1}, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(1, 1), ) checkpoint_manager.on_checkpoint(persistent_checkpoint) @@ -51,7 +55,7 @@ def testOnCheckpointOrdered(self): checkpoints = [ TrackedCheckpoint( dir_or_data={i}, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(i, i), ) for i in range(3) @@ -83,7 +87,7 @@ def testOnCheckpointUnordered(self): checkpoints = [ TrackedCheckpoint( dir_or_data={i}, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, 
metrics=self.mock_result(i, i), ) for i in range(3, -1, -1) @@ -117,7 +121,7 @@ def testBestCheckpoints(self): checkpoints = [ TrackedCheckpoint( dir_or_data=i, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(i, i), ) for i in range(8) @@ -142,14 +146,14 @@ def testBestCheckpointsWithNan(self): checkpoints = [ TrackedCheckpoint( dir_or_data=None, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(float("nan"), i), ) for i in range(2) ] + [ TrackedCheckpoint( dir_or_data=3, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(0, 3), ) ] @@ -174,7 +178,7 @@ def testBestCheckpointsOnlyNan(self): checkpoints = [ TrackedCheckpoint( dir_or_data=i, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(float("nan"), i), ) for i in range(4) @@ -198,7 +202,7 @@ def testOnCheckpointUnavailableAttribute(self): no_attr_checkpoint = TrackedCheckpoint( dir_or_data=0, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics={}, ) @@ -214,12 +218,12 @@ def testOnMemoryCheckpoint(self): checkpoints = [ TrackedCheckpoint( dir_or_data=0, - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, metrics=self.mock_result(0, 0), ), TrackedCheckpoint( dir_or_data=0, - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, metrics=self.mock_result(0, 0), ), ] @@ -248,22 +252,22 @@ def testSameCheckpoint(self): checkpoints = [ TrackedCheckpoint( dir_or_data=tmpfiles[0], - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(5, 5), ), TrackedCheckpoint( dir_or_data=tmpfiles[1], - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(10, 10), ), TrackedCheckpoint( dir_or_data=tmpfiles[2], - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(0, 0), ), TrackedCheckpoint( dir_or_data=tmpfiles[1], - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(20, 20), ), ] diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index 12053447e2cf..66c0622918e0 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -24,7 +24,7 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory from unittest.mock import patch -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import CheckpointStorage class TrialExecutorInsufficientResourcesTest(unittest.TestCase): @@ -120,7 +120,7 @@ def _simulate_getting_result(self, trial): trial.update_last_result(training_result) def _simulate_saving(self, trial): - checkpoint = self.trial_executor.save(trial, TrackedCheckpoint.PERSISTENT) + checkpoint = self.trial_executor.save(trial, CheckpointStorage.PERSISTENT) self.assertEqual(checkpoint, trial.saving_to) self.assertEqual(trial.checkpoint.dir_or_data, None) event = self.trial_executor.get_next_executor_event( @@ -189,7 +189,7 @@ def testSavePauseResumeErrorRestore(self): # Pause self.trial_executor.pause_trial(trial) self.assertEqual(Trial.PAUSED, trial.status) - 
self.assertEqual(trial.checkpoint.storage_mode, TrackedCheckpoint.MEMORY) + self.assertEqual(trial.checkpoint.storage_mode, CheckpointStorage.MEMORY) # Resume self._simulate_starting_trial(trial) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index d0a5dc084ad5..9c043e8c23e9 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -18,7 +18,7 @@ from ray.tune.suggest import BasicVariantGenerator from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver from ray.tune.utils.trainable import TrainableUtil -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage def create_mock_components(): @@ -335,7 +335,7 @@ def write_checkpoint(trial: Trial, index: int): tune_cp = TrackedCheckpoint( dir_or_data=checkpoint_dir, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=result, ) trial.saving_to = tune_cp diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 4ae7a7da2a3c..5e505c57c463 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -25,7 +25,7 @@ from ray.tune import Callback from ray.tune.utils.callback import create_default_callbacks from ray.tune.experiment import Experiment -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage class TestCallback(Callback): @@ -152,7 +152,7 @@ def testCallbackSteps(self): # Just a placeholder object ref for cp.value. 
cp = TrackedCheckpoint( dir_or_data=ray.put(1), - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics={TRAINING_ITERATION: 0}, ) trials[0].saving_to = cp diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 0af6fa1f084f..7d62f5f755d7 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -33,7 +33,7 @@ from ray.tune.resources import Resources from ray.rllib import _register_all -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage _register_all() @@ -249,10 +249,10 @@ def stop_trial(self, trial, error=False, error_msg=None): def restore(self, trial, checkpoint=None, block=False): pass - def save(self, trial, type=TrackedCheckpoint.PERSISTENT, result=None): + def save(self, trial, type=CheckpointStorage.PERSISTENT, result=None): return TrackedCheckpoint( dir_or_data=trial.trainable_name, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, metrics=result, ) @@ -312,7 +312,7 @@ def get_live_trials(self): return {t for t in self.trials if t.status != Trial.TERMINATED} def _pause_trial(self, trial): - self.trial_executor.save(trial, TrackedCheckpoint.MEMORY, None) + self.trial_executor.save(trial, CheckpointStorage.MEMORY, None) trial.status = Trial.PAUSED def _launch_trial(self, trial): @@ -849,7 +849,7 @@ def on_checkpoint(self, checkpoint): def checkpoint(self): return TrackedCheckpoint( dir_or_data=self.trainable_name, - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, metrics=None, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 542fef127e40..f7a7da2a6f4d 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -20,7 +20,7 @@ # Import psutil after ray so the packaged version is used. 
import psutil -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage MB = 1024 ** 2 @@ -442,7 +442,7 @@ class MockTrial(Trial): def checkpoint(self): return TrackedCheckpoint( dir_or_data="None", - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, metrics={}, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py index a0dfcc843ff5..4b0f2b35a0bc 100644 --- a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py +++ b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py @@ -8,7 +8,7 @@ DistributeResources, DistributeResourcesToTopJob, ) -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage class MockResourceUpdater: @@ -51,7 +51,7 @@ class MockTrial(Trial): def checkpoint(self): return TrackedCheckpoint( dir_or_data="None", - storage_mode=TrackedCheckpoint.MEMORY, + storage_mode=CheckpointStorage.MEMORY, metrics={}, ) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 30684269f0f0..a88e352dce40 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -40,7 +40,7 @@ from ray.tune.utils import date_str, flatten_dict from ray.util.annotations import DeveloperAPI from ray._private.utils import binary_to_hex, hex_to_binary -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage DEBUG_PRINT_INTERVAL = 5 logger = logging.getLogger(__name__) @@ -110,7 +110,7 @@ def __call__(self, checkpoint: TrackedCheckpoint): return if ( - checkpoint.storage_mode == TrackedCheckpoint.PERSISTENT + checkpoint.storage_mode == CheckpointStorage.PERSISTENT and checkpoint.dir_or_data ): checkpoint_path = checkpoint.dir_or_data @@ -469,7 +469,7 @@ def checkpoint(self): if checkpoint.dir_or_data is None: checkpoint = TrackedCheckpoint( dir_or_data=self.restore_path, - storage_mode=TrackedCheckpoint.PERSISTENT, + storage_mode=CheckpointStorage.PERSISTENT, ) return checkpoint diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 34bf96d7c5d6..a81b9ee997da 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -7,7 +7,7 @@ from ray.tune import TuneError from ray.util.annotations import DeveloperAPI from ray.tune.trial import Trial -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage logger = logging.getLogger(__name__) @@ -124,7 +124,7 @@ def pause_trial(self, trial: Trial) -> None: """ assert trial.status == Trial.RUNNING, trial.status try: - self.save(trial, TrackedCheckpoint.MEMORY) + self.save(trial, CheckpointStorage.MEMORY) self.stop_trial(trial) self.set_status(trial, Trial.PAUSED) except Exception: @@ -194,7 +194,7 @@ def restore(self, trial: Trial) -> None: def save( self, trial: Trial, - storage: str = TrackedCheckpoint.PERSISTENT, + storage: str = CheckpointStorage.PERSISTENT, result: Optional[Dict] = None, ) -> TrackedCheckpoint: """Saves training state of this trial to a checkpoint. 
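Whether an in-memory checkpoint also yields a persisted copy is controlled by the class attribute _persist_memory_checkpoints, set to True for the Ray Train manager (patch 54) and False for the Ray Tune manager (patch 55). A rough sketch of that behaviour, mirroring the unit tests added in the following patch (the data payload is a placeholder, and the private attributes are poked directly only because the tests do the same):

from ray.util.ml_utils.checkpoint_manager import (
    CheckpointManager,
    CheckpointStorage,
    CheckpointStrategy,
    TrackedCheckpoint,
)

manager = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None))
manager._persist_memory_checkpoints = True

# The memory checkpoint replaces the latest in-memory checkpoint; because the
# flag is enabled, a copy with CheckpointStorage.PERSISTENT is also tracked
# among the top persisted checkpoints.
manager.register_checkpoint(
    TrackedCheckpoint({"data": 0}, storage_mode=CheckpointStorage.MEMORY)
)
assert len(manager._top_persisted_checkpoints) == 1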
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index bd8c09ce3437..39e9ddc42631 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -42,7 +42,7 @@ from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder from ray.tune.web_server import TuneServer from ray.util.debug import log_once -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import CheckpointStorage MAX_DEBUG_TRIALS = 20 @@ -1117,7 +1117,7 @@ def _process_trial_save( checkpoint=trial.saving_to, ) trial.on_checkpoint(trial.saving_to) - if trial.checkpoint.storage_mode != TrackedCheckpoint.MEMORY: + if trial.checkpoint.storage_mode != CheckpointStorage.MEMORY: self.trial_executor.mark_trial_to_checkpoint(trial) except Exception: logger.exception( @@ -1204,7 +1204,7 @@ def _checkpoint_trial_if_needed(self, trial, force=False): if trial.should_checkpoint() or force: # Save trial runtime if possible. if trial.runner: - self.trial_executor.save(trial, storage=TrackedCheckpoint.PERSISTENT) + self.trial_executor.save(trial, storage=CheckpointStorage.PERSISTENT) def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]): """Tries to recover trial. From 839c7ab709b69641eeaa5530f8db934aaebe0d36 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:40:29 +0100 Subject: [PATCH 59/81] Add tests --- python/ray/util/ml_utils/BUILD | 8 ++ .../ml_utils/tests/test_checkpoint_manager.py | 95 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 python/ray/util/ml_utils/tests/test_checkpoint_manager.py diff --git a/python/ray/util/ml_utils/BUILD b/python/ray/util/ml_utils/BUILD index 8f43babe6148..b5ec98ba3328 100644 --- a/python/ray/util/ml_utils/BUILD +++ b/python/ray/util/ml_utils/BUILD @@ -2,6 +2,14 @@ # Tests from the python/ray/util/ml_util/tests directory. # Please keep these sorted alphabetically. 
# -------------------------------------------------------------------- +py_test( + name = "test_checkpoint_manager", + size = "small", + srcs = ["tests/test_checkpoint_manager.py"], + tags = ["team:ml", "exclusive"], + deps = [":ml_util_lib"] +) + py_test( name = "test_mlflow", size = "medium", diff --git a/python/ray/util/ml_utils/tests/test_checkpoint_manager.py b/python/ray/util/ml_utils/tests/test_checkpoint_manager.py new file mode 100644 index 000000000000..9aeb82229c43 --- /dev/null +++ b/python/ray/util/ml_utils/tests/test_checkpoint_manager.py @@ -0,0 +1,95 @@ +import pytest +from ray.util.ml_utils.checkpoint_manager import ( + CheckpointManager, + CheckpointStorage, + CheckpointStrategy, + TrackedCheckpoint, +) + + +def test_unlimited_persistent_checkpoints(): + cpm = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None)) + + for i in range(10): + cpm.register_checkpoint( + TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT) + ) + + assert len(cpm._top_persisted_checkpoints) == 10 + + +def test_limited_persistent_checkpoints(): + cpm = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=2)) + + for i in range(10): + cpm.register_checkpoint( + TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT) + ) + + assert len(cpm._top_persisted_checkpoints) == 2 + + +def test_no_persistent_checkpoints(): + cpm = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=0)) + + for i in range(10): + cpm.register_checkpoint( + TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT) + ) + + assert len(cpm._top_persisted_checkpoints) == 0 + + +def test_dont_persist_memory_checkpoints(): + cpm = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None)) + cpm._persist_memory_checkpoints = False + + for i in range(10): + cpm.register_checkpoint( + TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY) + ) + + assert len(cpm._top_persisted_checkpoints) == 0 + + +def test_persist_memory_checkpoints(): + cpm = CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None)) + cpm._persist_memory_checkpoints = True + + for i in range(10): + cpm.register_checkpoint( + TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY) + ) + + assert len(cpm._top_persisted_checkpoints) == 10 + + +def test_keep_best_checkpoints(): + cpm = CheckpointManager( + checkpoint_strategy=CheckpointStrategy( + num_to_keep=2, + checkpoint_score_attribute="metric", + checkpoint_score_order="min", + ) + ) + cpm._persist_memory_checkpoints = True + + for i in range(10): + cpm.register_checkpoint( + TrackedCheckpoint( + {"data": i}, + storage_mode=CheckpointStorage.MEMORY, + metrics={"metric": i}, + ) + ) + + # Sorted from worst (max) to best (min) + assert [ + cp.tracked_checkpoint.metrics["metric"] for cp in cpm._top_persisted_checkpoints + ] == [1, 0] + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) From 048e6677dd3f8418aef5c3c22f7cb6ac559d0e63 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 15:43:38 +0100 Subject: [PATCH 60/81] Re-order --- python/ray/util/ml_utils/checkpoint_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 710fa0b724d2..e23f80b7f940 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -82,14 
+82,14 @@ def commit(self, path: Optional[Path] = None) -> None: # Do not persist memory checkpoints return - if not isinstance(self.dir_or_data, dict): - # Only persist dictionaries - return - if not path: # If no path is given, skip return + if not isinstance(self.dir_or_data, dict): + # Only persist dictionaries + return + checkpoint = Checkpoint.from_dict(self.dir_or_data) self.dir_or_data = checkpoint.to_directory(str(path)) From 4af1de74b1892eead5298143e886f32a7a678e3d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 17 May 2022 16:57:30 +0100 Subject: [PATCH 61/81] Checkpoints are now directories, not files --- python/ray/train/checkpoint.py | 5 +++-- python/ray/train/tests/test_trainer.py | 2 +- python/ray/train/tests/test_tune.py | 10 +++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index a65d72269b14..8fa47b983eea 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -15,7 +15,8 @@ from ray.util.ml_utils.checkpoint_manager import ( CheckpointManager as CommonCheckpointManager, TrackedCheckpoint, - CheckpointStrategy, CheckpointStorage, + CheckpointStrategy, + CheckpointStorage, ) if TUNE_INSTALLED: @@ -206,7 +207,7 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): - self.add_tune_checkpoint_id(checkpoint._data_to_commit) + self.add_tune_checkpoint_id(checkpoint.dir_or_data) # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index cc017de709b9..2511c85b755e 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -495,7 +495,7 @@ def train_func(): if logdir is not None: assert trainer.logdir == Path(logdir).expanduser().resolve() assert trainer.latest_checkpoint_dir.is_dir() - assert trainer.best_checkpoint_path.is_file() + assert trainer.best_checkpoint_path.is_dir() assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}" assert trainer.best_checkpoint_path.parent.name == "checkpoints" assert trainer.best_checkpoint == trainer.latest_checkpoint diff --git a/python/ray/train/tests/test_tune.py b/python/ray/train/tests/test_tune.py index bee476100746..90566aaa39a0 100644 --- a/python/ray/train/tests/test_tune.py +++ b/python/ray/train/tests/test_tune.py @@ -4,6 +4,7 @@ import ray import ray.train as train from ray import tune, cloudpickle +from ray.ml import Checkpoint from ray.tune import TuneError from ray.train import Trainer from ray.train.backend import Backend, BackendConfig @@ -117,11 +118,10 @@ def train_func(): TestTrainable = trainer.to_tune_trainable(train_func) [trial] = tune.run(TestTrainable).trials - checkpoint_file = os.path.join(trial.checkpoint.value, TUNE_CHECKPOINT_FILE_NAME) - assert os.path.exists(checkpoint_file) - with open(checkpoint_file, "rb") as f: - checkpoint = cloudpickle.load(f) - assert checkpoint["hello"] == "world" + checkpoint_path = os.path.join(trial.checkpoint.value, TUNE_CHECKPOINT_FILE_NAME) + assert os.path.exists(checkpoint_path) + checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict() + assert checkpoint["hello"] == "world" def test_reuse_checkpoint(ray_start_2_cpus): From 47bccdd13f104f956156f4ce6ecf45b840d5f1f2 Mon 
Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 25 May 2022 15:41:10 +0200 Subject: [PATCH 62/81] Privatize _checkpoint.py --- .../train/{checkpoint.py => _checkpoint.py} | 22 +++++++++++-------- python/ray/train/trainer.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) rename python/ray/train/{checkpoint.py => _checkpoint.py} (92%) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/_checkpoint.py similarity index 92% rename from python/ray/train/checkpoint.py rename to python/ray/train/_checkpoint.py index 8fa47b983eea..e6b3121e62d8 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -13,8 +13,8 @@ from ray.train.session import TrainingResult from ray.train.utils import construct_path from ray.util.ml_utils.checkpoint_manager import ( - CheckpointManager as CommonCheckpointManager, - TrackedCheckpoint, + _CheckpointManager as CommonCheckpointManager, + _TrackedCheckpoint, CheckpointStrategy, CheckpointStorage, ) @@ -114,7 +114,7 @@ def _process_checkpoint( f"train.save_checkpoint." ) - tracked_checkpoint = TrackedCheckpoint( + tracked_checkpoint = _TrackedCheckpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=CheckpointStorage.MEMORY, @@ -124,7 +124,9 @@ def _process_checkpoint( def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" - checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + checkpoint_file = _construct_checkpoint_file_name( + self._latest_checkpoint_id + 1 + ) return self.latest_checkpoint_dir.joinpath(checkpoint_file) def on_start_training( @@ -158,14 +160,16 @@ def latest_checkpoint_dir(self) -> Optional[Path]: def latest_checkpoint_file_name(self) -> Optional[str]: """Filename to use for the latest checkpoint.""" if self._latest_checkpoint_id > 0: - return construct_checkpoint_file_name(self._latest_checkpoint_id) + return _construct_checkpoint_file_name(self._latest_checkpoint_id) else: return None @property def next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" - checkpoint_file = construct_checkpoint_file_name(self._latest_checkpoint_id + 1) + checkpoint_file = _construct_checkpoint_file_name( + self._latest_checkpoint_id + 1 + ) return self.latest_checkpoint_dir.joinpath(checkpoint_file) @property @@ -206,7 +210,7 @@ def add_tune_checkpoint_id(self, checkpoint: Dict): # resumed after failure or cancellation. checkpoint[TUNE_CHECKPOINT_ID] = self._latest_checkpoint_id - def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): + def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): self.add_tune_checkpoint_id(checkpoint.dir_or_data) # If inside a Tune Trainable, then checkpoint with Tune. 
with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: @@ -216,8 +220,8 @@ def _decide_what_to_do_with_checkpoint(self, checkpoint: TrackedCheckpoint): file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) checkpoint.commit(file_path) - return super()._decide_what_to_do_with_checkpoint(checkpoint) + return super()._process_persistent_checkpoint(checkpoint) -def construct_checkpoint_file_name(checkpoint_id: int) -> str: +def _construct_checkpoint_file_name(checkpoint_id: int) -> str: return f"checkpoint_{checkpoint_id:06d}" diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index efb39e9c3535..617c1d3feef9 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -22,7 +22,7 @@ construct_train_func, ActorWrapper, ) -from ray.train.checkpoint import ( +from ray.train._checkpoint import ( CheckpointStrategy, TuneCheckpointManager, CheckpointManager, From a0be76b072b1ac8a8bdb9625601193c7d8d377e3 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 25 May 2022 17:14:20 +0200 Subject: [PATCH 63/81] Privatize _checkpoint.py --- python/ray/ml/train/data_parallel_trainer.py | 2 +- python/ray/train/__init__.py | 2 +- python/ray/train/tests/test_trainer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index e07366cee246..d39cb44fbea7 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -24,7 +24,7 @@ from ray.ml.checkpoint import Checkpoint from ray.train import BackendConfig, TrainingIterator from ray.train.backend import BackendExecutor -from ray.train.checkpoint import TuneCheckpointManager +from ray.train._checkpoint import TuneCheckpointManager from ray.train.impl.dataset_spec import _RayDatasetSpec from ray.train.utils import construct_train_func from ray.util.annotations import DeveloperAPI diff --git a/python/ray/train/__init__.py b/python/ray/train/__init__.py index 21c80b3511bd..b5fa13c0ca37 100644 --- a/python/ray/train/__init__.py +++ b/python/ray/train/__init__.py @@ -1,6 +1,6 @@ from ray.train.backend import BackendConfig from ray.train.callbacks import TrainingCallback -from ray.train.checkpoint import CheckpointStrategy +from ray.train._checkpoint import CheckpointStrategy from ray.train.session import ( get_dataset_shard, local_rank, diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index 2511c85b755e..f6135e6c52da 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -531,7 +531,7 @@ def train_func(): if logdir is not None: assert trainer.logdir == Path(logdir).expanduser().resolve() assert trainer.latest_checkpoint_dir.is_dir() - assert trainer.best_checkpoint_path.is_file() + assert trainer.best_checkpoint_path.is_dir() assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}" assert trainer.latest_checkpoint["loss"] == 5 assert trainer.best_checkpoint["loss"] == 3 From 7ec26caac8cb0d50b699f48407f462f393c67ffd Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 25 May 2022 17:35:29 +0200 Subject: [PATCH 64/81] Rename variables --- python/ray/train/_checkpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py index e6b3121e62d8..008c8cf80568 100644 --- a/python/ray/train/_checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -124,10 +124,10 @@ def 
_process_checkpoint( def _get_next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" - checkpoint_file = _construct_checkpoint_file_name( + checkpoint_path = _construct_checkpoint_path_name( self._latest_checkpoint_id + 1 ) - return self.latest_checkpoint_dir.joinpath(checkpoint_file) + return self.latest_checkpoint_dir.joinpath(checkpoint_path) def on_start_training( self, @@ -160,14 +160,14 @@ def latest_checkpoint_dir(self) -> Optional[Path]: def latest_checkpoint_file_name(self) -> Optional[str]: """Filename to use for the latest checkpoint.""" if self._latest_checkpoint_id > 0: - return _construct_checkpoint_file_name(self._latest_checkpoint_id) + return _construct_checkpoint_path_name(self._latest_checkpoint_id) else: return None @property def next_checkpoint_path(self) -> Optional[Path]: """Path to the next checkpoint to persist.""" - checkpoint_file = _construct_checkpoint_file_name( + checkpoint_file = _construct_checkpoint_path_name( self._latest_checkpoint_id + 1 ) return self.latest_checkpoint_dir.joinpath(checkpoint_file) @@ -223,5 +223,5 @@ def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): return super()._process_persistent_checkpoint(checkpoint) -def _construct_checkpoint_file_name(checkpoint_id: int) -> str: +def _construct_checkpoint_path_name(checkpoint_id: int) -> str: return f"checkpoint_{checkpoint_id:06d}" From d1c5ddb23544a6eb60e162320d75e09001329225 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 26 May 2022 17:25:47 +0200 Subject: [PATCH 65/81] Optional init parameters --- python/ray/train/_checkpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py index 008c8cf80568..928494196910 100644 --- a/python/ray/train/_checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -66,7 +66,11 @@ class CheckpointManager(CommonCheckpointManager): _persist_memory_checkpoints = True - def __init__(self, run_dir: Path, checkpoint_strategy: CheckpointStrategy): + def __init__( + self, + run_dir: Optional[Path] = None, + checkpoint_strategy: Optional[CheckpointStrategy] = None, + ): self.run_dir = run_dir super().__init__(checkpoint_strategy=checkpoint_strategy) From 5eadbcd6c1151f3e5b75237ffbbf6b8951e2a363 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 27 May 2022 12:37:12 +0200 Subject: [PATCH 66/81] Update DP trainer / HF trainer --- python/ray/ml/train/data_parallel_trainer.py | 27 +++--- .../huggingface/huggingface_trainer.py | 85 +++++++++++-------- 2 files changed, 65 insertions(+), 47 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index d39cb44fbea7..465ff5e5cecb 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -28,6 +28,7 @@ from ray.train.impl.dataset_spec import _RayDatasetSpec from ray.train.utils import construct_train_func from ray.util.annotations import DeveloperAPI +from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, _TrackedCheckpoint if TYPE_CHECKING: from ray.data import Dataset @@ -37,20 +38,22 @@ # TODO(team-ml): Refactor checkpoint management along with Tune. 
class _DataParallelCheckpointManager(TuneCheckpointManager): - def on_init(self, preprocessor: Preprocessor): + def __init__( + self, + preprocessor: Preprocessor, + run_dir: Optional[Path] = None, + checkpoint_strategy: Optional[CheckpointStrategy] = None, + ): self.preprocessor = preprocessor - super(_DataParallelCheckpointManager, self).on_init() - - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) - - # Add the preprocessor to the checkpoint. - checkpoint[PREPROCESSOR_KEY] = self.preprocessor + super(_DataParallelCheckpointManager, self).__init__( + run_dir=run_dir, checkpoint_strategy=checkpoint_strategy + ) - checkpoint_obj = Checkpoint.from_dict(checkpoint) - # If inside a Tune Trainable, then checkpoint with Tune. - with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: - checkpoint_obj.to_directory(path=checkpoint_dir) + def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): + checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor + super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint( + checkpoint=checkpoint + ) @property def latest_checkpoint_dir(self) -> Optional[Path]: diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 596900254e03..46cd3b54d1d0 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -11,16 +11,16 @@ import transformers import transformers.modeling_utils import transformers.trainer +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME import transformers.training_args from torch.utils.data import Dataset as TorchDataset from ray import train -from ray import tune from ray.util import PublicAPI, get_node_ip_address from ray.ml.checkpoint import Checkpoint from ray.ml.config import RunConfig, ScalingConfig -from ray.ml.constants import EVALUATION_DATASET_KEY, TRAIN_DATASET_KEY +from ray.ml.constants import EVALUATION_DATASET_KEY, TRAIN_DATASET_KEY, PREPROCESSOR_KEY from ray.ml.preprocessor import Preprocessor from ray.ml.train.integrations.torch import TorchTrainer from ray.ml.trainer import GenDataset @@ -53,39 +53,54 @@ # TODO(ml-team): Make dir syncing checkpoint logic generic. -# The checkpoint is turned into a dict with node ip & path -# in HuggingFaceTrainer.as_trainable -# TODO(team-ml): Refactor checkpoint management along with Tune. +class _SyncedTrackedCheckpoint(_TrackedCheckpoint): + def commit(self, path: Optional[Path] = None) -> None: + if ( + self.storage_mode == CheckpointStorage.MEMORY + or not path + or not isinstance(self.dir_or_data, dict) + ): + return + + source_ip = self.dir_or_data[NODE_IP_KEY] + source_path = self.dir_or_data[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + + if source_ip == target_ip: + # Move contents of source_path, but not source_path + # itself. shutil.move is already recursive. 
+ for inner in Path(source_path).iterdir(): + shutil.move(str(path.absolute()), inner) + shutil.rmtree(source_path, ignore_errors=True) + else: + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=str(path), + return_futures=False, + max_size_bytes=None, + ) + delete_on_node(node_ip=source_ip, path=source_path) + save_preprocessor_to_dir(self.dir_or_data.pop(PREPROCESSOR_KEY, None), path) + # add tune checkpoint id + with open(path.joinpath(TUNE_CHECKPOINT_ID), "w") as f: + f.write(str(self.id)) + + class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): - """As _DataParallelCheckpointManager, but syncs the dir instead of serializing.""" - - def write_checkpoint(self, checkpoint: Dict): - # If inside a Tune Trainable, then checkpoint with Tune. - with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: - source_ip = checkpoint[NODE_IP_KEY] - source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] - target_ip = get_node_ip_address() - if source_ip == target_ip: - # Move contents of source_path, but not source_path - # itself. shutil.move is already recursive. - for path in Path(source_path).iterdir(): - shutil.move(str(path.absolute()), checkpoint_dir) - shutil.rmtree(source_path, ignore_errors=True) - else: - sync_dir_between_nodes( - source_ip=source_ip, - source_path=source_path, - target_ip=target_ip, - target_path=checkpoint_dir, - return_futures=False, - max_size_bytes=None, - ) - delete_on_node(node_ip=source_ip, path=source_path) - checkpoint_dir = Path(checkpoint_dir) - save_preprocessor_to_dir(self.preprocessor, checkpoint_dir) - # add tune checkpoint id - with open(checkpoint_dir.joinpath(TUNE_CHECKPOINT_ID), "w") as f: - f.write(str(self._latest_checkpoint_id)) + def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): + sync_checkpoint = _SyncedTrackedCheckpoint( + dir_or_data=checkpoint.dir_or_data, + storage_mode=checkpoint.storage_mode, + checkpoint_id=checkpoint.id, + metrics=checkpoint.metrics, + node_ip=checkpoint.node_ip, + ) + + super( + _DataParallelSyncingCheckpointManager, self + )._process_persistent_checkpoint(checkpoint=sync_checkpoint) @PublicAPI(stability="alpha") @@ -328,7 +343,7 @@ def _convert_directory_checkpoint_to_sync_if_needed( ) -> Checkpoint: """Replace the directory checkpoint with a node ip & path dict checkpoint. - This dict checkpoint will be used used to sync the directory. + This dict checkpoint will be used to sync the directory. If we were to use a directory checkpoint directly, it would get deepcopied & serialized unnecessarily.""" with checkpoint.as_directory() as checkpoint_path: From 58f6ee81de7582d2262708d8296600e2953fa09e Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 2 Jun 2022 10:57:13 +0100 Subject: [PATCH 67/81] lint --- python/ray/train/_checkpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py index 928494196910..e43ef042c71b 100644 --- a/python/ray/train/_checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -49,18 +49,18 @@ class CheckpointManager(CommonCheckpointManager): checkpoint_ Attributes: - latest_checkpoint_dir (Optional[Path]): Path to the file directory for + latest_checkpoint_dir: Path to the file directory for the checkpoints from the latest run. Configured through ``start_training``. 
- latest_checkpoint_filename (Optional[str]): Filename for the latest + latest_checkpoint_filename: Filename for the latest checkpoint. - next_checkpoint_path (Optional[Path]): Path to the next checkpoint to + next_checkpoint_path: Path to the next checkpoint to persist from the latest run. - best_checkpoint_path (Optional[Path]): Path to the best persisted + best_checkpoint_path: Path to the best persisted checkpoint from the latest run. - latest_checkpoint_id (Optional[int]): The id of the most recently + latest_checkpoint_id: The id of the most recently saved checkpoint. - latest_checkpoint (Optional[Dict]): The latest saved checkpoint. This + latest_checkpoint: The latest saved checkpoint. This checkpoint may not be saved to disk. """ From 5a8894877b4a37a2c486c41840c4017813ba3236 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 2 Jun 2022 11:04:37 +0100 Subject: [PATCH 68/81] Fix tune tests --- python/ray/ml/train/data_parallel_trainer.py | 5 ++-- python/ray/train/_checkpoint.py | 6 +---- python/ray/train/constants.py | 3 --- python/ray/train/tests/test_tune.py | 24 ++++++++------------ python/ray/train/trainer.py | 8 +------ 5 files changed, 14 insertions(+), 32 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 465ff5e5cecb..419bbce48ea7 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -327,8 +327,9 @@ def training_loop(self) -> None: max_retries=0, ) - checkpoint_manager = self._checkpoint_manager_cls() - checkpoint_manager.on_init(preprocessor=self.preprocessor) + checkpoint_manager = self._checkpoint_manager_cls( + preprocessor=self.preprocessor + ) # Start the remote actors. backend_executor.start(initialization_hook=None) diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py index e43ef042c71b..d95bbeaf0988 100644 --- a/python/ray/train/_checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -6,7 +6,6 @@ from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, - TUNE_CHECKPOINT_FILE_NAME, TUNE_CHECKPOINT_ID, TUNE_INSTALLED, ) @@ -219,10 +218,7 @@ def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: path = Path(checkpoint_dir) - # Use a standard file name so that we know which file to load - # the checkpoint from. - file_path = path.joinpath(TUNE_CHECKPOINT_FILE_NAME) - checkpoint.commit(file_path) + checkpoint.commit(path) return super()._process_persistent_checkpoint(checkpoint) diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py index 8dc5abb5960d..533de9eaa3aa 100644 --- a/python/ray/train/constants.py +++ b/python/ray/train/constants.py @@ -33,9 +33,6 @@ # Default directory where all Train logs, checkpoints, etc. will be stored. DEFAULT_RESULTS_DIR = Path("~/ray_results").expanduser() -# File name to use for checkpoints saved with Tune. -TUNE_CHECKPOINT_FILE_NAME = "checkpoint" - # The name of the subdirectory inside the trainer run_dir to store checkpoints. 
TRAIN_CHECKPOINT_SUBDIR = "checkpoints" diff --git a/python/ray/train/tests/test_tune.py b/python/ray/train/tests/test_tune.py index 90566aaa39a0..85a295e2db7b 100644 --- a/python/ray/train/tests/test_tune.py +++ b/python/ray/train/tests/test_tune.py @@ -8,7 +8,6 @@ from ray.tune import TuneError from ray.train import Trainer from ray.train.backend import Backend, BackendConfig -from ray.train.constants import TUNE_CHECKPOINT_FILE_NAME from ray.train.examples.tensorflow_mnist_example import ( train_func as tensorflow_mnist_train_func, ) @@ -118,7 +117,7 @@ def train_func(): TestTrainable = trainer.to_tune_trainable(train_func) [trial] = tune.run(TestTrainable).trials - checkpoint_path = os.path.join(trial.checkpoint.value, TUNE_CHECKPOINT_FILE_NAME) + checkpoint_path = trial.checkpoint.value assert os.path.exists(checkpoint_path) checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict() assert checkpoint["hello"] == "world" @@ -139,13 +138,10 @@ def train_func(config): TestTrainable = trainer.to_tune_trainable(train_func) [trial] = tune.run(TestTrainable, config={"max_iter": 5}).trials - last_ckpt = trial.checkpoint.value - checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME) - assert os.path.exists(checkpoint_file) - with open(checkpoint_file, "rb") as f: - checkpoint = cloudpickle.load(f) - assert checkpoint["iter"] == 4 - analysis = tune.run(TestTrainable, config={"max_iter": 10}, restore=last_ckpt) + checkpoint_path = trial.checkpoint.value + checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict() + assert checkpoint["iter"] == 4 + analysis = tune.run(TestTrainable, config={"max_iter": 10}, restore=checkpoint_path) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 5 @@ -168,12 +164,10 @@ def train_func(): TestTrainable = trainer.to_tune_trainable(train_func) analysis = tune.run(TestTrainable, max_failures=3) - last_ckpt = analysis.trials[0].checkpoint.value - checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME) - assert os.path.exists(checkpoint_file) - with open(checkpoint_file, "rb") as f: - checkpoint = cloudpickle.load(f) - assert checkpoint["iter"] == 3 + checkpoint_path = analysis.trials[0].checkpoint.value + checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict() + assert checkpoint["iter"] == 3 + trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 4 diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 7693658a6af3..220d55663b68 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -31,7 +31,6 @@ from ray.train.constants import ( TUNE_INSTALLED, DEFAULT_RESULTS_DIR, - TUNE_CHECKPOINT_FILE_NAME, ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, @@ -881,13 +880,8 @@ def tune_function(config, checkpoint_dir=None): trainer.start() - if checkpoint_dir is not None: - checkpoint_path = os.path.join(checkpoint_dir, TUNE_CHECKPOINT_FILE_NAME) - else: - checkpoint_path = None - iterator = trainer.run_iterator( - train_func, config, dataset=dataset, checkpoint=checkpoint_path + train_func, config, dataset=dataset, checkpoint=checkpoint_dir ) for results in iterator: From 23948ff85944bb40127e6044ed1efa142fb8c27f Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 2 Jun 2022 11:30:23 +0100 Subject: [PATCH 69/81] Fix some merge issues --- python/ray/train/tests/test_tune.py | 2 +- python/ray/tune/callback.py | 4 +- 
python/ray/tune/checkpoint_manager.py | 33 +++++---------- python/ray/tune/ray_trial_executor.py | 8 ++-- python/ray/tune/syncer.py | 6 +-- .../ray/tune/tests/test_checkpoint_manager.py | 40 +++++++++---------- python/ray/tune/tests/test_trial_runner_2.py | 4 +- .../tune/tests/test_trial_runner_callbacks.py | 4 +- python/ray/tune/tests/test_trial_scheduler.py | 6 +-- .../tune/tests/test_trial_scheduler_pbt.py | 4 +- .../test_trial_scheduler_resource_changing.py | 4 +- python/ray/tune/trial.py | 8 ++-- python/ray/tune/trial_executor.py | 4 +- .../ray/util/ml_utils/checkpoint_manager.py | 6 +-- 14 files changed, 60 insertions(+), 73 deletions(-) diff --git a/python/ray/train/tests/test_tune.py b/python/ray/train/tests/test_tune.py index 43f6884ec54a..3e4a8e5fc9f4 100644 --- a/python/ray/train/tests/test_tune.py +++ b/python/ray/train/tests/test_tune.py @@ -3,7 +3,7 @@ import pytest import ray import ray.train as train -from ray import tune, cloudpickle +from ray import tune from ray.ml import Checkpoint from ray.tune import TuneError from ray.train import Trainer diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index 7ad486116e22..66cd78adb72d 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -3,7 +3,7 @@ import warnings from ray.util.annotations import PublicAPI, DeveloperAPI -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint if TYPE_CHECKING: from ray.tune.trial import Trial @@ -245,7 +245,7 @@ def on_checkpoint( iteration: int, trials: List["Trial"], trial: "Trial", - checkpoint: TrackedCheckpoint, + checkpoint: _TrackedCheckpoint, **info, ): """Called after a trial saved a checkpoint with Tune. diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 3eb34552f530..64b68a7fb416 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -7,12 +7,10 @@ CheckpointStrategy, MIN, MAX, - CheckpointManager as CommonCheckpointManager, - TrackedCheckpoint, + _CheckpointManager as CommonCheckpointManager, + _TrackedCheckpoint, CheckpointStorage, ) -from ray.util.ml_utils.dict import flatten_dict -from ray.util.ml_utils.util import is_nan logger = logging.getLogger(__name__) @@ -37,7 +35,7 @@ def __init__( self, keep_checkpoints_num: int, checkpoint_score_attr: Optional[str], - delete_fn: Optional[Callable[["TrackedCheckpoint"], None]] = None, + delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, ): if keep_checkpoints_num == 0: raise RuntimeError( @@ -61,7 +59,7 @@ def __init__( super().__init__(checkpoint_strategy=checkpoint_strategy, delete_fn=delete_fn) - def handle_checkpoint(self, checkpoint: TrackedCheckpoint): + def handle_checkpoint(self, checkpoint: _TrackedCheckpoint): # Set checkpoint ID checkpoint.id = checkpoint.id or self._latest_checkpoint_id self._latest_checkpoint_id += 1 @@ -74,14 +72,14 @@ def handle_checkpoint(self, checkpoint: TrackedCheckpoint): self._checkpoint_strategy.num_to_keep is None or self._checkpoint_strategy.num_to_keep > 0 ) - self._decide_what_to_do_with_checkpoint(checkpoint) + self._process_persistent_checkpoint(checkpoint) - def on_checkpoint(self, checkpoint: TrackedCheckpoint): + def on_checkpoint(self, checkpoint: _TrackedCheckpoint): """Ray Tune's entrypoint""" # Todo (krfricke): Replace with handle_checkpoint. 
self.handle_checkpoint(checkpoint) - def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): + def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): assert persisted_checkpoint.storage_mode == CheckpointStorage.PERSISTENT super()._skip_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) # Ray Tune always keeps track of the latest persisted checkpoint. @@ -95,7 +93,7 @@ def _skip_persisted_checkpoint(self, persisted_checkpoint: TrackedCheckpoint): @property def newest_persistent_checkpoint(self): - return self._latest_persisted_checkpoint or TrackedCheckpoint( + return self._latest_persisted_checkpoint or _TrackedCheckpoint( dir_or_data=None, checkpoint_id=-1, storage_mode=CheckpointStorage.PERSISTENT, @@ -112,7 +110,7 @@ def newest_checkpoint(self): @property def newest_memory_checkpoint(self): - return self._latest_memory_checkpoint or TrackedCheckpoint( + return self._latest_memory_checkpoint or _TrackedCheckpoint( dir_or_data=None, checkpoint_id=-1, storage_mode=CheckpointStorage.MEMORY, @@ -123,21 +121,10 @@ def best_checkpoints(self): checkpoints = sorted(self._top_persisted_checkpoints, key=lambda c: c.priority) return [wrapped.tracked_checkpoint for wrapped in checkpoints] - def _priority(self, checkpoint): - result = flatten_dict(checkpoint.result) - priority = result[self._checkpoint_score_attr] - if self._checkpoint_score_desc: - priority = -priority - return ( - not is_nan(priority), - priority if not is_nan(priority) else 0, - checkpoint.order, - ) - def __getstate__(self): state = self.__dict__.copy() # Avoid serializing the memory checkpoint. - state["_newest_memory_checkpoint"] = TrackedCheckpoint( + state["_newest_memory_checkpoint"] = _TrackedCheckpoint( CheckpointStorage.MEMORY, None ) # Avoid serializing lambda since it may capture cyclical dependencies. diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index d6f93e33f298..a52bf56e76ae 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -38,7 +38,7 @@ from ray.tune.utils.resource_updater import _ResourceUpdater from ray.util import log_once from ray.util.annotations import DeveloperAPI -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage from ray.util.placement_group import remove_placement_group, PlacementGroup logger = logging.getLogger(__name__) @@ -679,7 +679,7 @@ def save( trial: Trial, storage: str = CheckpointStorage.PERSISTENT, result: Optional[Dict] = None, - ) -> TrackedCheckpoint: + ) -> _TrackedCheckpoint: """Saves the trial's state to a checkpoint asynchronously. 
Args: @@ -697,13 +697,13 @@ def save( with self._change_working_directory(trial): if storage == CheckpointStorage.MEMORY: value = trial.runner.save_to_object.remote() - checkpoint = TrackedCheckpoint( + checkpoint = _TrackedCheckpoint( dir_or_data=value, storage_mode=storage, metrics=result ) trial.on_checkpoint(checkpoint) else: value = trial.runner.save.remote() - checkpoint = TrackedCheckpoint( + checkpoint = _TrackedCheckpoint( dir_or_data=value, storage_mode=storage, metrics=result ) trial.saving_to = checkpoint diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index 8f49fc83042f..66f22df79547 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -37,7 +37,7 @@ RemoteTaskClient, ) from ray.util.annotations import PublicAPI, DeveloperAPI -from ray.util.ml_utils.checkpoint_manager import CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint if TYPE_CHECKING: from ray.tune.trial import Trial @@ -531,7 +531,7 @@ def _create_trial_syncer(self, trial: "Trial"): def _remove_trial_syncer(self, trial: "Trial"): self._syncers.pop(trial, None) - def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: TrackedCheckpoint): + def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint): if checkpoint.storage_mode == CheckpointStorage.MEMORY: return @@ -623,7 +623,7 @@ def on_checkpoint( iteration: int, trials: List["Trial"], trial: "Trial", - checkpoint: TrackedCheckpoint, + checkpoint: _TrackedCheckpoint, **info, ): self._sync_trial_checkpoint(trial, checkpoint) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 7e44ccf85fec..12df79fcb03b 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -9,7 +9,7 @@ from ray.tune.result import TRAINING_ITERATION from ray.tune.checkpoint_manager import _CheckpointManager from ray.util.ml_utils.checkpoint_manager import ( - TrackedCheckpoint, + _TrackedCheckpoint, logger, CheckpointStorage, ) @@ -29,14 +29,14 @@ def checkpoint_manager(self, keep_checkpoints_num): def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - memory_checkpoint = TrackedCheckpoint( - dir_or_data={0}, + memory_checkpoint = _TrackedCheckpoint( + dir_or_data={"a": 0}, storage_mode=CheckpointStorage.MEMORY, metrics=self.mock_result(0, 0), ) checkpoint_manager.on_checkpoint(memory_checkpoint) - persistent_checkpoint = TrackedCheckpoint( - dir_or_data={1}, + persistent_checkpoint = _TrackedCheckpoint( + dir_or_data={"a": 1}, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(1, 1), ) @@ -53,7 +53,7 @@ def testOnCheckpointOrdered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data={i}, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(i, i), @@ -85,7 +85,7 @@ def testOnCheckpointUnordered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data={i}, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(i, i), @@ -119,7 +119,7 @@ def testBestCheckpoints(self): """ keep_checkpoints_num = 4 checkpoints = [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=i, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(i, 
i), @@ -144,14 +144,14 @@ def testBestCheckpointsWithNan(self): """ keep_checkpoints_num = 2 checkpoints = [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=None, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(float("nan"), i), ) for i in range(2) ] + [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=3, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(0, 3), @@ -176,7 +176,7 @@ def testBestCheckpointsOnlyNan(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=i, storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(float("nan"), i), @@ -200,7 +200,7 @@ def testOnCheckpointUnavailableAttribute(self): """ checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - no_attr_checkpoint = TrackedCheckpoint( + no_attr_checkpoint = _TrackedCheckpoint( dir_or_data=0, storage_mode=CheckpointStorage.PERSISTENT, metrics={}, @@ -216,13 +216,13 @@ def testOnCheckpointUnavailableAttribute(self): def testOnMemoryCheckpoint(self): checkpoints = [ - TrackedCheckpoint( - dir_or_data=0, + _TrackedCheckpoint( + dir_or_data={"a": 0}, storage_mode=CheckpointStorage.MEMORY, metrics=self.mock_result(0, 0), ), - TrackedCheckpoint( - dir_or_data=0, + _TrackedCheckpoint( + dir_or_data={"a": 0}, storage_mode=CheckpointStorage.MEMORY, metrics=self.mock_result(0, 0), ), @@ -250,22 +250,22 @@ def testSameCheckpoint(self): tmpfiles.append(tmpfile) checkpoints = [ - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=tmpfiles[0], storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(5, 5), ), - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=tmpfiles[1], storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(10, 10), ), - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=tmpfiles[2], storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(0, 0), ), - TrackedCheckpoint( + _TrackedCheckpoint( dir_or_data=tmpfiles[1], storage_mode=CheckpointStorage.PERSISTENT, metrics=self.mock_result(20, 20), diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 9c043e8c23e9..ead0717a4a72 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -18,7 +18,7 @@ from ray.tune.suggest import BasicVariantGenerator from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver from ray.tune.utils.trainable import TrainableUtil -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage def create_mock_components(): @@ -333,7 +333,7 @@ def write_checkpoint(trial: Trial, index: int): with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f: json.dump(result, f) - tune_cp = TrackedCheckpoint( + tune_cp = _TrackedCheckpoint( dir_or_data=checkpoint_dir, storage_mode=CheckpointStorage.PERSISTENT, metrics=result, diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 14c163c6f4a0..c73631dcf6f4 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -25,7 +25,7 @@ from ray.tune import Callback from ray.tune.utils.callback import create_default_callbacks from ray.tune.experiment import Experiment -from 
ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage class TestCallback(Callback): @@ -150,7 +150,7 @@ def testCallbackSteps(self): self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "two") # Just a placeholder object ref for cp.value. - cp = TrackedCheckpoint( + cp = _TrackedCheckpoint( dir_or_data=ray.put(1), storage_mode=CheckpointStorage.PERSISTENT, metrics={TRAINING_ITERATION: 0}, diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 08a2dd3b8823..4af9b31f9459 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -33,7 +33,7 @@ from ray.tune.resources import Resources from ray.rllib import _register_all -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage _register_all() @@ -250,7 +250,7 @@ def restore(self, trial, checkpoint=None, block=False): pass def save(self, trial, type=CheckpointStorage.PERSISTENT, result=None): - return TrackedCheckpoint( + return _TrackedCheckpoint( dir_or_data=trial.trainable_name, storage_mode=CheckpointStorage.PERSISTENT, metrics=result, @@ -847,7 +847,7 @@ def on_checkpoint(self, checkpoint): @property def checkpoint(self): - return TrackedCheckpoint( + return _TrackedCheckpoint( dir_or_data=self.trainable_name, storage_mode=CheckpointStorage.MEMORY, metrics=None, diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index f7a7da2a6f4d..14417d3731f2 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -20,7 +20,7 @@ # Import psutil after ray so the packaged version is used. 
import psutil -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage MB = 1024 ** 2 @@ -440,7 +440,7 @@ def testBurnInPeriod(self): class MockTrial(Trial): @property def checkpoint(self): - return TrackedCheckpoint( + return _TrackedCheckpoint( dir_or_data="None", storage_mode=CheckpointStorage.MEMORY, metrics={}, diff --git a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py index d56bfd4ab717..0f9b58c58443 100644 --- a/python/ray/tune/tests/test_trial_scheduler_resource_changing.py +++ b/python/ray/tune/tests/test_trial_scheduler_resource_changing.py @@ -8,7 +8,7 @@ DistributeResources, DistributeResourcesToTopJob, ) -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage class MockResourceUpdater: @@ -49,7 +49,7 @@ def get_trials(self): class MockTrial(Trial): @property def checkpoint(self): - return TrackedCheckpoint( + return _TrackedCheckpoint( dir_or_data="None", storage_mode=CheckpointStorage.MEMORY, metrics={}, diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index fa486f38d19e..815abc807cf4 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -40,7 +40,7 @@ from ray.tune.utils import date_str, flatten_dict from ray.util.annotations import DeveloperAPI from ray._private.utils import binary_to_hex, hex_to_binary -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage DEBUG_PRINT_INTERVAL = 5 logger = logging.getLogger(__name__) @@ -100,7 +100,7 @@ def __init__(self, trial_id, runner): self.trial_id = trial_id self.runner = runner - def __call__(self, checkpoint: TrackedCheckpoint): + def __call__(self, checkpoint: _TrackedCheckpoint): """Requests checkpoint deletion asynchronously. Args: @@ -467,7 +467,7 @@ def checkpoint(self): else: checkpoint = self.checkpoint_manager.newest_checkpoint if checkpoint.dir_or_data is None: - checkpoint = TrackedCheckpoint( + checkpoint = _TrackedCheckpoint( dir_or_data=self.restore_path, storage_mode=CheckpointStorage.PERSISTENT, ) @@ -655,7 +655,7 @@ def clear_checkpoint(self): self.restoring_from = None self.invalidate_json_state() - def on_checkpoint(self, checkpoint: TrackedCheckpoint): + def on_checkpoint(self, checkpoint: _TrackedCheckpoint): """Hook for handling checkpoints taken by the Trainable. Args: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index a81b9ee997da..3870d22970e7 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -7,7 +7,7 @@ from ray.tune import TuneError from ray.util.annotations import DeveloperAPI from ray.tune.trial import Trial -from ray.util.ml_utils.checkpoint_manager import TrackedCheckpoint, CheckpointStorage +from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage logger = logging.getLogger(__name__) @@ -196,7 +196,7 @@ def save( trial: Trial, storage: str = CheckpointStorage.PERSISTENT, result: Optional[Dict] = None, - ) -> TrackedCheckpoint: + ) -> _TrackedCheckpoint: """Saves training state of this trial to a checkpoint. If result is None, this trial's last result will be used. 
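For orientation, the private bookkeeping API that these renames touch can be exercised on its own. The sketch below is purely illustrative: the import path and keyword names are taken from the hunks in this patch, while the concrete values and the checkpoint path are invented for the example (these are internal, version-pinned interfaces, not a public API):

from ray.util.ml_utils.checkpoint_manager import (
    CheckpointStorage,
    _TrackedCheckpoint,
)

# In-memory checkpoints carry a dict (or a Ray object reference)...
memory_cp = _TrackedCheckpoint(
    dir_or_data={"iter": 4, "weights": [0.1, 0.2]},
    storage_mode=CheckpointStorage.MEMORY,
    metrics={"training_iteration": 4, "loss": 0.25},
)

# ...while persistent checkpoints reference a directory or file on disk.
disk_cp = _TrackedCheckpoint(
    dir_or_data="/tmp/example_run/checkpoint_000004",  # hypothetical path
    storage_mode=CheckpointStorage.PERSISTENT,
    checkpoint_id=4,
    metrics={"training_iteration": 4, "loss": 0.25},
)

# Observed metrics and the checkpoint id are attached as plain attributes.
print(memory_cp.metrics["loss"], disk_cp.id)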
diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index 5926a1e3fbb2..aae410c259d0 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -118,10 +118,10 @@ def delete( def __repr__(self): if self.storage_mode == CheckpointStorage.MEMORY: - return f"" + return f"<_TrackedCheckpoint storage='MEMORY' result={self.metrics}>" return ( - f"" ) @@ -210,7 +210,7 @@ class _CheckpointManager: keeps a configured number of checkpoints according to specified metrics. The manager supports lazy data writing by utilizing the - ``TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint + ``_TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint should be persisted to disk. Args: From 8c934596fc7697a01b7a99ba87e4789def12f4e8 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 2 Jun 2022 14:23:08 +0100 Subject: [PATCH 70/81] Tune checkpoint manager --- python/ray/ml/train/data_parallel_trainer.py | 4 ---- python/ray/train/_checkpoint.py | 13 ++++++++++++- python/ray/train/tests/test_trainer.py | 2 +- python/ray/train/tests/test_tune.py | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 419bbce48ea7..68d0b73d0c52 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -55,10 +55,6 @@ def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): checkpoint=checkpoint ) - @property - def latest_checkpoint_dir(self) -> Optional[Path]: - raise NotImplementedError - @DeveloperAPI class DataParallelTrainer(Trainer): diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py index d95bbeaf0988..a654ad10b347 100644 --- a/python/ray/train/_checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -135,7 +135,7 @@ def _get_next_checkpoint_path(self) -> Optional[Path]: def on_start_training( self, checkpoint_strategy: Optional[CheckpointStrategy], - run_dir: str, + run_dir: Path, latest_checkpoint_id: Optional[int] = 0, ): checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() @@ -222,6 +222,17 @@ def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): return super()._process_persistent_checkpoint(checkpoint) + @property + def latest_checkpoint_dir(self) -> Optional[Path]: + raise NotImplementedError + + @property + def next_checkpoint_path(self) -> Optional[Path]: + return None + + def _get_next_checkpoint_path(self) -> Optional[Path]: + return None + def _construct_checkpoint_path_name(checkpoint_id: int) -> str: return f"checkpoint_{checkpoint_id:06d}" diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index f6135e6c52da..7de5560fabfb 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -537,7 +537,7 @@ def train_func(): assert trainer.best_checkpoint["loss"] == 3 checkpoint_dir = trainer.latest_checkpoint_dir - file_names = [f.name for f in checkpoint_dir.iterdir()] + file_names = [f.name for f in checkpoint_dir.iterdir() if f.is_dir()] assert len(file_names) == 2 assert f"checkpoint_{2:06d}" in file_names assert f"checkpoint_{3:06d}" not in file_names diff --git a/python/ray/train/tests/test_tune.py b/python/ray/train/tests/test_tune.py index 85a295e2db7b..d4156385f544 100644 --- a/python/ray/train/tests/test_tune.py +++ 
b/python/ray/train/tests/test_tune.py @@ -3,7 +3,7 @@ import pytest import ray import ray.train as train -from ray import tune, cloudpickle +from ray import tune from ray.ml import Checkpoint from ray.tune import TuneError from ray.train import Trainer From b8b944d87c52b536a28fd540ace6f80ce0384542 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 2 Jun 2022 14:26:10 +0100 Subject: [PATCH 71/81] Allow None values for memory checkpoints --- python/ray/util/ml_utils/checkpoint_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py index aae410c259d0..8a54565b9cdb 100644 --- a/python/ray/util/ml_utils/checkpoint_manager.py +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -72,8 +72,10 @@ def __init__( self.metrics = metrics or {} self.node_ip = node_ip or self.metrics.get(NODE_IP, None) - if storage_mode == CheckpointStorage.MEMORY and not isinstance( - dir_or_data, (dict, ray.ObjectRef) + if ( + dir_or_data is not None + and storage_mode == CheckpointStorage.MEMORY + and not isinstance(dir_or_data, (dict, ray.ObjectRef)) ): raise ValueError( f"Memory checkpoints only support Ray object references and dicts " From 60eb3a2614de1baa2e18edf9cbad059d2d5c28da Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 2 Jun 2022 17:35:40 +0100 Subject: [PATCH 72/81] Fix huggingface trainer --- .../ml/train/integrations/huggingface/huggingface_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 46cd3b54d1d0..ee0ae4950b49 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -70,7 +70,7 @@ def commit(self, path: Optional[Path] = None) -> None: # Move contents of source_path, but not source_path # itself. shutil.move is already recursive. 
for inner in Path(source_path).iterdir(): - shutil.move(str(path.absolute()), inner) + shutil.move(str(inner.absolute()), str(path)) shutil.rmtree(source_path, ignore_errors=True) else: sync_dir_between_nodes( From 402e1ee52ef4a62f020792b2f97a10ee171c6fc6 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 3 Jun 2022 10:37:28 +0100 Subject: [PATCH 73/81] result -> metrics --- python/ray/tune/trial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 815abc807cf4..f91df6f1dc60 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -667,7 +667,7 @@ def on_checkpoint(self, checkpoint: _TrackedCheckpoint): def on_restore(self): """Handles restoration completion.""" assert self.is_restoring - self.last_result = self.restoring_from.result + self.last_result = self.restoring_from.metrics self.restoring_from = None self.invalidate_json_state() From 9f6dca47234f1c74ba6944db1f0527ec259f7c4d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 3 Jun 2022 22:38:07 +0100 Subject: [PATCH 74/81] ml -> air --- python/ray/train/_checkpoint.py | 2 +- python/ray/train/tests/test_tune.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py index a654ad10b347..baa8003108fa 100644 --- a/python/ray/train/_checkpoint.py +++ b/python/ray/train/_checkpoint.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List, Optional, Dict, Union, Callable -from ray.ml import Checkpoint +from ray.air import Checkpoint from ray.train.constants import ( TIMESTAMP, TRAIN_CHECKPOINT_SUBDIR, diff --git a/python/ray/train/tests/test_tune.py b/python/ray/train/tests/test_tune.py index d4156385f544..2ddd73bbdbfa 100644 --- a/python/ray/train/tests/test_tune.py +++ b/python/ray/train/tests/test_tune.py @@ -4,7 +4,7 @@ import ray import ray.train as train from ray import tune -from ray.ml import Checkpoint +from ray.air import Checkpoint from ray.tune import TuneError from ray.train import Trainer from ray.train.backend import Backend, BackendConfig From 228dd275397652202ab9f5be167120314c764e62 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 6 Jun 2022 10:43:56 +0100 Subject: [PATCH 75/81] Merge conflicts --- python/ray/air/train/data_parallel_trainer.py | 1 + .../train/integrations/huggingface/huggingface_trainer.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/air/train/data_parallel_trainer.py b/python/ray/air/train/data_parallel_trainer.py index c4e19b01ecc6..c45ee301d358 100644 --- a/python/ray/air/train/data_parallel_trainer.py +++ b/python/ray/air/train/data_parallel_trainer.py @@ -32,6 +32,7 @@ from ray.util.annotations import DeveloperAPI from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, _TrackedCheckpoint + logger = logging.getLogger(__name__) diff --git a/python/ray/air/train/integrations/huggingface/huggingface_trainer.py b/python/ray/air/train/integrations/huggingface/huggingface_trainer.py index d77bb60baebb..010b7db35ff4 100644 --- a/python/ray/air/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/air/train/integrations/huggingface/huggingface_trainer.py @@ -20,7 +20,11 @@ from ray.util import PublicAPI, get_node_ip_address from ray.air.checkpoint import Checkpoint from ray.air.config import RunConfig, ScalingConfig, DatasetConfig -from ray.air.constants import EVALUATION_DATASET_KEY, TRAIN_DATASET_KEY +from ray.air.constants import ( + EVALUATION_DATASET_KEY, 
+ TRAIN_DATASET_KEY, + PREPROCESSOR_KEY, +) from ray.air.preprocessor import Preprocessor from ray.air.train.integrations.torch import TorchTrainer from ray.air.trainer import GenDataset From c92f293dcffebcf7c92af2ab58ab9cdd49e2a2dd Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 6 Jun 2022 10:54:06 +0100 Subject: [PATCH 76/81] Fix CheckpointStorage --- python/ray/tune/integration/comet.py | 2 +- python/ray/tune/integration/wandb.py | 2 +- python/ray/tune/ray_trial_executor.py | 4 ++-- python/ray/tune/tests/test_trial_scheduler.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/integration/comet.py b/python/ray/tune/integration/comet.py index f675fe25ae7c..3791072346d2 100644 --- a/python/ray/tune/integration/comet.py +++ b/python/ray/tune/integration/comet.py @@ -224,7 +224,7 @@ def log_trial_save(self, trial: "Trial"): ) # Walk through checkpoint directory and add all files to artifact - checkpoint_root = trial.checkpoint.value + checkpoint_root = trial.checkpoint.dir_or_data for root, dirs, files in os.walk(checkpoint_root): rel_root = os.path.relpath(root, checkpoint_root) for file in files: diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 2be77200cc7b..4821ad6816ca 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -425,7 +425,7 @@ def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): def log_trial_save(self, trial: "Trial"): if self.save_checkpoints and trial.checkpoint: self._trial_queues[trial].put( - (_QueueItem.CHECKPOINT, trial.checkpoint.value) + (_QueueItem.CHECKPOINT, trial.checkpoint.dir_or_data) ) def log_trial_end(self, trial: "Trial", failed: bool = False): diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 4a139e2978aa..e420b0c4fb78 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -608,7 +608,7 @@ def pause_trial(self, trial: Trial) -> None: """ assert trial.status == Trial.RUNNING, trial.status try: - self.save(trial, _TuneCheckpoint.MEMORY) + self.save(trial, CheckpointStorage.MEMORY) self.stop_trial(trial) self.set_status(trial, Trial.PAUSED) except Exception: @@ -724,7 +724,7 @@ def force_reconcilation_on_next_step_end(self) -> None: def save( self, trial: Trial, - storage: str = CheckpointStorage.PERSISTENT, + storage: CheckpointStorage = CheckpointStorage.PERSISTENT, result: Optional[Dict] = None, ) -> _TrackedCheckpoint: """Saves the trial's state to a checkpoint asynchronously. 
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index cc04ab13a658..ee51287c8742 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -239,7 +239,7 @@ def result2(t, rew): class _MockTrialExecutor(RayTrialExecutor): def start_trial(self, trial, checkpoint_obj=None, train=True): trial.logger_running = True - trial.restored_checkpoint = checkpoint_obj.value + trial.restored_checkpoint = checkpoint_obj.dir_or_data trial.status = Trial.RUNNING return True From 5411958d7b789ee0082b1cfbf6ba2a8f0b0addfc Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 6 Jun 2022 15:46:07 +0100 Subject: [PATCH 77/81] Fix tests with strings in memory checkpoints --- python/ray/tune/tests/test_trial_scheduler.py | 7 +++++-- python/ray/tune/tests/test_trial_scheduler_pbt.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index ee51287c8742..af66ac3dbdc3 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -843,12 +843,15 @@ def __init__(self, i, config): self._default_result_or_future = None def on_checkpoint(self, checkpoint): - self.restored_checkpoint = checkpoint.dir_or_data + if checkpoint.storage_mode == CheckpointStorage.MEMORY: + self.restored_checkpoint = checkpoint.dir_or_data["data"] + else: + self.restored_checkpoint = checkpoint.dir_or_data @property def checkpoint(self): return _TrackedCheckpoint( - dir_or_data=self.trainable_name, + dir_or_data={"data": self.trainable_name}, storage_mode=CheckpointStorage.MEMORY, metrics=None, ) diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 14417d3731f2..ddc393d72d23 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -441,7 +441,7 @@ class MockTrial(Trial): @property def checkpoint(self): return _TrackedCheckpoint( - dir_or_data="None", + dir_or_data={"data": "None"}, storage_mode=CheckpointStorage.MEMORY, metrics={}, ) From 73ea9e554ee49e442f2802ebea60ef067985c622 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 7 Jun 2022 09:59:56 +0100 Subject: [PATCH 78/81] Address comments --- .../air/train/integrations/huggingface/huggingface_trainer.py | 3 +++ python/ray/train/__init__.py | 2 +- python/ray/train/trainer.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/ray/air/train/integrations/huggingface/huggingface_trainer.py b/python/ray/air/train/integrations/huggingface/huggingface_trainer.py index 010b7db35ff4..882f677e594c 100644 --- a/python/ray/air/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/air/train/integrations/huggingface/huggingface_trainer.py @@ -57,6 +57,9 @@ # TODO(ml-team): Make dir syncing checkpoint logic generic. +# The checkpoint is turned into a dict with node ip & path +# in HuggingFaceTrainer.as_trainable +# TODO(team-ml): Refactor checkpoint management along with Tune. 
class _SyncedTrackedCheckpoint(_TrackedCheckpoint): def commit(self, path: Optional[Path] = None) -> None: if ( diff --git a/python/ray/train/__init__.py b/python/ray/train/__init__.py index b5fa13c0ca37..e84d4561173f 100644 --- a/python/ray/train/__init__.py +++ b/python/ray/train/__init__.py @@ -1,6 +1,5 @@ from ray.train.backend import BackendConfig from ray.train.callbacks import TrainingCallback -from ray.train._checkpoint import CheckpointStrategy from ray.train.session import ( get_dataset_shard, local_rank, @@ -11,6 +10,7 @@ world_size, ) from ray.train.trainer import Trainer, TrainingIterator +from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy from ray._private.usage import usage_lib diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 220d55663b68..7ab29bd65959 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -23,7 +23,6 @@ ActorWrapper, ) from ray.train._checkpoint import ( - CheckpointStrategy, TuneCheckpointManager, CheckpointManager, load_checkpoint_from_path, @@ -42,6 +41,7 @@ from ray.train.worker_group import WorkerGroup from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI +from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy if TUNE_INSTALLED: from ray import tune From fceb5c30272629404f3adcc2bd0dd9f868c0597c Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 7 Jun 2022 11:28:20 +0100 Subject: [PATCH 79/81] Install pandas in minimal env --- ci/env/install-minimal.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/env/install-minimal.sh b/ci/env/install-minimal.sh index e2721b001ff0..bd7d8947d90c 100755 --- a/ci/env/install-minimal.sh +++ b/ci/env/install-minimal.sh @@ -34,4 +34,5 @@ eval "${WORKSPACE_DIR}/ci/ci.sh build" # Install test requirements python -m pip install -U \ pytest==5.4.3 \ - numpy + numpy \ + pandas From 5a670b1fbcaca0c18862d4878faa27d3637a2a35 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 7 Jun 2022 11:41:18 +0100 Subject: [PATCH 80/81] Type checking import --- ci/env/install-minimal.sh | 3 +-- python/ray/air/predictor.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ci/env/install-minimal.sh b/ci/env/install-minimal.sh index bd7d8947d90c..e2721b001ff0 100755 --- a/ci/env/install-minimal.sh +++ b/ci/env/install-minimal.sh @@ -34,5 +34,4 @@ eval "${WORKSPACE_DIR}/ci/ci.sh build" # Install test requirements python -m pip install -U \ pytest==5.4.3 \ - numpy \ - pandas + numpy diff --git a/python/ray/air/predictor.py b/python/ray/air/predictor.py index 495612185f22..a4d76f97c54a 100644 --- a/python/ray/air/predictor.py +++ b/python/ray/air/predictor.py @@ -4,13 +4,14 @@ from ray.air.checkpoint import Checkpoint from ray.util.annotations import DeveloperAPI, PublicAPI -import numpy as np -import pandas as pd - if TYPE_CHECKING: + import numpy as np + import pandas as pd import pyarrow -DataBatchType = Union[np.ndarray, pd.DataFrame, "pyarrow.Table", Dict[str, np.ndarray]] +DataBatchType = Union[ + "np.ndarray", "pd.DataFrame", "pyarrow.Table", Dict[str, "np.ndarray"] +] @PublicAPI(stability="alpha") @@ -95,7 +96,7 @@ def predict(self, data: DataBatchType, **kwargs) -> DataBatchType: # return _convert_pandas_to_batch_type(predictions_df, type=type(data)) @DeveloperAPI - def _predict_pandas(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame: + def _predict_pandas(self, data: "pd.DataFrame", **kwargs) -> "pd.DataFrame": """Perform inference on a Pandas DataFrame. 
All predictors are expected to implement this method. @@ -118,13 +119,13 @@ def __reduce__(self): ) -def _convert_batch_type_to_pandas(data: DataBatchType) -> pd.DataFrame: +def _convert_batch_type_to_pandas(data: DataBatchType) -> "pd.DataFrame": """Convert the provided data to a Pandas DataFrame.""" pass def _convert_pandas_to_batch_type( - data: pd.DataFrame, type: Type[DataBatchType] + data: "pd.DataFrame", type: Type[DataBatchType] ) -> DataBatchType: """Convert the provided Pandas dataframe to the provided ``type``.""" From 74cabb23d17330fb192c9203e73d7170337d29e9 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 7 Jun 2022 11:48:52 +0100 Subject: [PATCH 81/81] Fix example --- python/ray/tune/examples/tf_mnist_example.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/tune/examples/tf_mnist_example.py b/python/ray/tune/examples/tf_mnist_example.py index 1468a18d6d90..e2bde13e4106 100644 --- a/python/ray/tune/examples/tf_mnist_example.py +++ b/python/ray/tune/examples/tf_mnist_example.py @@ -109,10 +109,10 @@ def step(self): # It is important to return tf.Tensors as numpy objects. return { "epoch": self.iteration, - "loss": self.train_loss.metrics().numpy(), - "accuracy": self.train_accuracy.metrics().numpy() * 100, - "test_loss": self.test_loss.metrics().numpy(), - "mean_accuracy": self.test_accuracy.metrics().numpy() * 100, + "loss": self.train_loss.result().numpy(), + "accuracy": self.train_accuracy.result().numpy() * 100, + "test_loss": self.test_loss.result().numpy(), + "mean_accuracy": self.test_accuracy.result().numpy() * 100, }