Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Fix storage client creation when sync function tpl is not provided (#26714) #26717

Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions python/ray/tune/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def get_state(self):
"ray_version": ray.__version__,
}

def save(self, checkpoint_dir: Optional[str] = None) -> str:
def save(self, checkpoint_dir: Optional[str] = None, prevent_upload: bool = False) -> str:
"""Saves the current model state to a checkpoint.

Subclasses should override ``save_checkpoint()`` instead to save state.
Expand All @@ -440,6 +440,7 @@ def save(self, checkpoint_dir: Optional[str] = None) -> str:

Args:
checkpoint_dir: Optional dir to place the checkpoint.
prevent_upload: bool flag to stop tmp folders from uploading

Returns:
str: path that points to xxx.pkl file.
Expand All @@ -459,7 +460,8 @@ def save(self, checkpoint_dir: Optional[str] = None) -> str:
)

# Maybe sync to cloud
self._maybe_save_to_cloud(checkpoint_dir)
if not prevent_upload:
self._maybe_save_to_cloud(checkpoint_dir)

return checkpoint_path

Expand All @@ -486,12 +488,12 @@ def save_to_object(self):
"""Saves the current model state to a Python object.

It also saves to disk but does not return the checkpoint path.

It doesn't save to cloud.
Returns:
Object holding checkpoint data.
"""
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
checkpoint_path = self.save(tmpdir)
checkpoint_path = self.save(tmpdir, prevent_upload=True)
# Save all files in subtree and delete the tmpdir.
obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
shutil.rmtree(tmpdir)
Expand Down Expand Up @@ -544,6 +546,11 @@ def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None
# Only keep for backwards compatibility
self.storage_client.sync_down(external_uri, local_dir)
self.storage_client.wait_or_retry()
elif os.path.exists(checkpoint_path):
try:
TrainableUtil.find_checkpoint_dir(checkpoint_path)
except Exception:
pass
rohit-annigeri marked this conversation as resolved.
Show resolved Hide resolved
else:
checkpoint = Checkpoint.from_uri(external_uri)
retry_fn(
Expand Down