[AIR] Avoid checkpoint conversion, move encoding logic to checkpoints (ray-project#28794)

This PR avoids always converting to a dictionary when reporting a checkpoint in Train and uses Checkpoints instead of dicts to transfer data. This is a low-hanging-fruit change for better consistency and performance with non-dict checkpoints. To facilitate it, the data encoding logic in Ray Train has been modified: encoding and decoding are now done in the checkpoint classes. I believe this is the cleanest solution, as it is both generic and inherently tied to the checkpoint itself; however, it has the downside of requiring users to use the correct checkpoint classes for Torch and Horovod. To maintain backwards compatibility, the checkpoint class is automatically changed in session.py if a Torch checkpoint is required (which has extra encoding and decoding logic to deal with serialization issues). Warnings are printed where necessary.
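For illustration, a minimal sketch of the reporting pattern this change enables, mirroring the example updates further down in this diff (session.report, Checkpoint.from_dict, and TorchCheckpoint.from_state_dict are the APIs that appear in the diff; the model, loop, and config values are placeholders):

import torch.nn as nn

from ray.air import session
from ray.train.torch import TorchCheckpoint


def train_func(config):
    # Placeholder model and loss; a real training loop would go here.
    model = nn.Linear(4, 1)
    for epoch in range(config.get("epochs", 2)):
        loss = 0.0
        # Before: checkpoints were reported as plain dicts, e.g.
        #   checkpoint=Checkpoint.from_dict({"model": model.state_dict()}),
        # and were always round-tripped through a dict on the way to the driver.
        # Now: a framework-specific checkpoint is reported directly, and the
        # Torch-specific encoding/decoding lives in TorchCheckpoint itself.
        session.report(
            {"loss": loss, "epoch": epoch},
            checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
        )

A training function like this would typically be passed to TorchTrainer(train_loop_per_worker=train_func, ...).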

Signed-off-by: Antoni Baum <[email protected]>
Co-authored-by: Kai Fricke <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
2 people authored and WeichenXu123 committed Dec 19, 2022
1 parent 86e14e5 commit cbce66d
Showing 24 changed files with 404 additions and 174 deletions.
20 changes: 20 additions & 0 deletions python/ray/air/_internal/tensorflow_utils.py
@@ -64,6 +64,26 @@ def convert_ndarray_batch_to_tf_tensor_batch(
return batch


# This is not foolproof, but it's better than nothing.
# The code path that uses this will be deprecated soon.
def contains_tensorflow_object(obj):
if hasattr(obj, "__module__") and (
"keras" in obj.__module__ or "tensorflow" in obj.__module__
):
return True
elif isinstance(obj, dict):
for k, v in obj.items():
if contains_tensorflow_object(k):
return True
if contains_tensorflow_object(v):
return True
elif isinstance(obj, (list, tuple)):
for v in obj:
if contains_tensorflow_object(v):
return True
return False


def get_type_spec(
schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"],
columns: Union[str, List[str]],
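A quick usage sketch of the contains_tensorflow_object helper added above (the nested payload and the Keras layer are purely illustrative; the helper simply walks dicts, lists, and tuples looking for objects whose module name mentions keras or tensorflow):

import tensorflow as tf

from ray.air._internal.tensorflow_utils import contains_tensorflow_object

payload = {"metrics": [{"layer": tf.keras.layers.Dense(1)}]}
print(contains_tensorflow_object(payload))        # True: a Keras object is nested inside
print(contains_tensorflow_object({"loss": 0.5}))  # False: only plain Python values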
9 changes: 6 additions & 3 deletions python/ray/air/checkpoint.py
@@ -501,6 +501,11 @@ def _get_temporary_checkpoint_dir(self) -> str:
)
return os.path.join(tmp_dir_path, checkpoint_dir_name)

def _save_checkpoint_metadata_in_directory(self, path: str) -> None:
checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
with open(checkpoint_metadata_path, "wb") as file:
pickle.dump(self._metadata, file)

def _to_directory(self, path: str) -> None:
if self._data_dict or self._obj_ref:
# This is an object ref or dict
@@ -547,9 +552,7 @@ def _to_directory(self, path: str) -> None:
f"No valid location found for checkpoint {self}: {self._uri}"
)

checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
with open(checkpoint_metadata_path, "wb") as file:
pickle.dump(self._metadata, file)
self._save_checkpoint_metadata_in_directory(path)

def to_directory(self, path: Optional[str] = None) -> str:
"""Write checkpoint data to directory.
2 changes: 1 addition & 1 deletion python/ray/train/_internal/backend_executor.py
@@ -346,7 +346,7 @@ def initialize_session(
train_func=train_func,
dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint,
encode_data_fn=self._backend.encode_data,
encode_data_fn=self._backend._encode_data,
)
)

12 changes: 9 additions & 3 deletions python/ray/train/_internal/checkpoint.py
@@ -101,13 +101,15 @@ def _process_checkpoint(
"""Ray Train entrypoint. Perform all processing for a checkpoint."""
# Get checkpoint from first worker.
checkpoint_data = checkpoint_results[0].data
checkpoint_metadata = checkpoint_results[0].metadata or {}

# Decode checkpoint.
checkpoint_data = decode_checkpoint_fn(checkpoint_data)
# TODO(ml-team): Remove once we remove Backend.decode_data
checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()

score_attr = self._checkpoint_strategy.checkpoint_score_attribute
if (
self._checkpoint_strategy.num_to_keep != 0
and score_attr not in checkpoint_metadata
and score_attr not in checkpoint_data
):
raise ValueError(
@@ -122,7 +124,11 @@
dir_or_data=checkpoint_data,
checkpoint_id=self._latest_checkpoint_id,
storage_mode=CheckpointStorage.MEMORY,
metrics={score_attr: checkpoint_data.get(score_attr, 0.0)},
metrics={
score_attr: checkpoint_metadata.get(
score_attr, checkpoint_data.get(score_attr, 0.0)
)
},
)
self.register_checkpoint(checkpoint=tracked_checkpoint)

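A tiny illustrative snippet of the score lookup introduced above: the checkpoint score attribute is now read from the result metadata first, then from the decoded checkpoint dict, and finally defaults to 0.0 (the function and variable names below are made up for illustration):

def lookup_score(score_attr, checkpoint_metadata, checkpoint_data):
    # Mirrors the precedence in _process_checkpoint: metadata wins over the
    # decoded checkpoint dict, and 0.0 is the fallback.
    return checkpoint_metadata.get(score_attr, checkpoint_data.get(score_attr, 0.0))


print(lookup_score("loss", {"loss": 0.1}, {"loss": 0.9}))  # 0.1 (metadata wins)
print(lookup_score("loss", {}, {"loss": 0.9}))             # 0.9 (falls back to checkpoint data)
print(lookup_score("loss", {}, {}))                        # 0.0 (default)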
51 changes: 36 additions & 15 deletions python/ray/train/_internal/session.py
@@ -2,6 +2,7 @@
import logging
import platform
import queue
import sys
import threading
import time
from dataclasses import dataclass
@@ -51,7 +52,8 @@ class TrialInfo:
@dataclass
class TrainingResult:
type: TrainingResultType
data: Dict
data: Union[Dict, Checkpoint]
metadata: Optional[Dict] = None


# TODO(xwjiang): This needs a better name.
@@ -68,8 +70,9 @@ def __init__(
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
checkpoint: Optional[Union[Dict, Checkpoint]] = None,
encode_data_fn: Callable = None,
checkpoint: Optional[Checkpoint] = None,
# Deprecated
encode_data_fn: Optional[Callable] = None,
detailed_autofilled_metrics: bool = False,
):

@@ -80,7 +83,7 @@ def __init__(
self.world_size = world_size
self.trial_info = trial_info
# TODO(xwjiang): Legacy Ray Train trainer clean up!
self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint
self.loaded_checkpoint = checkpoint

# Function to encode checkpoint dict before sending to the driver.
if not encode_data_fn:
@@ -240,9 +243,9 @@ def _report_legacy(self, **kwargs):
if self.ignore_report:
return

kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs))
kwargs = self._auto_fill_metrics(kwargs)

result = TrainingResult(TrainingResultType.REPORT, kwargs)
result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs)

# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)
@@ -269,22 +272,26 @@ def _report_thread_runner_error(self, block=False):
except queue.Empty:
pass

def checkpoint(self, **kwargs):
def checkpoint(self, checkpoint: Checkpoint):
"""Adds the checkpoint to the queue to be consumed by the main thread.
Also stores the checkpoint in ``self.loaded_checkpoint``.
"""

# Update session checkpoint to latest checkpoint.
self.loaded_checkpoint = kwargs
self.loaded_checkpoint = checkpoint

# Only store checkpoints on worker with rank 0.
if self.world_rank != 0:
kwargs = {}
else:
kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))

result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
checkpoint = None
elif checkpoint:
checkpoint = self._encode_data_fn(checkpoint)

result = TrainingResult(
type=TrainingResultType.CHECKPOINT,
data=checkpoint,
metadata=self._auto_fill_checkpoint_metrics({}),
)
# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)

@@ -294,9 +301,23 @@ def checkpoint(self, **kwargs):

def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): tons of optimizations.

# Special case: early fail for Torch tensors
if "torch" in sys.modules:
from ray.air._internal.torch_utils import contains_tensor

if contains_tensor(metrics):
raise ValueError(
"Passing objects containing Torch tensors as metrics "
"is not supported as it will throw an exception on "
"deserialization. You can either convert the tensors "
"to Python objects or use a `TorchCheckpoint` as the "
"`checkpoint` argument of `ray.air.session.report` to "
"store your Torch objects."
)

if checkpoint:
checkpoint_dict = checkpoint.to_dict()
self.checkpoint(**checkpoint_dict)
self.checkpoint(checkpoint)
self._report_legacy(**metrics)


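To sidestep the new early failure for Torch tensors in metrics shown above, a sketch of the intended usage from inside a Train training function (session.report and TorchCheckpoint are the APIs used in this diff; the model and loss are assumed to come from a surrounding Torch training loop):

import torch
import torch.nn as nn

from ray.air import session
from ray.train.torch import TorchCheckpoint


def report_epoch(model: nn.Module, loss: torch.Tensor) -> None:
    # session.report({"loss": loss}, ...) would now raise, because `loss` is a
    # torch.Tensor and tensors in metrics fail on deserialization.
    # Convert metrics to plain Python values and keep Torch objects inside a
    # TorchCheckpoint instead.
    session.report(
        {"loss": loss.item()},
        checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
    )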
65 changes: 62 additions & 3 deletions python/ray/train/backend.py
@@ -1,16 +1,43 @@
import logging
from typing import TypeVar, Dict
import warnings
from typing import Type, TypeVar, Dict

from ray.air.checkpoint import Checkpoint
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI

from ray.util.annotations import Deprecated, DeveloperAPI
from ray.widgets import make_table_html_repr

EncodedData = TypeVar("EncodedData")

logger = logging.getLogger(__name__)

# This is used in several places to print a warning.
_encode_decode_deprecation_message = (
"``encode_data`` and ``decode_data`` are deprecated in favor of "
"framework-specific ``ray.air.Checkpoint`` subclasses (reported "
"using ``ray.air.session.report()``) which can implement "
"encoding and decoding logic. In the future, ``encode_data`` and "
"``decode_data`` will throw an exception if overridden."
)


def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]):
return
# Do not print warnings in 2.1 yet.
# TODO(ml-team): Change this once we have full API parity with framework
# checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning
# warnings.warn(
# f"You have reported a checkpoint with the `{Checkpoint}` "
# "type, but the intended checkpoint type for the Trainer "
# f"you are using is `{expected_checkpoint_cls}`. "
# "Not using the intended checkpoint type may cause "
# "exceptions or other issues, especially during "
# "serialization and deserialization. The checkpoint "
# "type will be changed automatically. "
# "This behavior may change in the future."
# )


@DeveloperAPI
class BackendConfig:
@@ -46,6 +73,37 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for shutting down the backend."""
pass

@classmethod
def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``encode_data`` is deprecated."""
if cls.encode_data != Backend.encode_data:
warnings.warn(
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
)
# We wrap the return value of encode_data in a dict in case it is
# not a dict itself.
checkpoint = checkpoint.from_dict(
{"encoded_data": cls.encode_data(checkpoint.to_dict())}
)
return checkpoint

@classmethod
def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``decode_data`` is deprecated."""
if cls.decode_data != Backend.decode_data:
warnings.warn(
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
)
checkpoint_dict = checkpoint.to_dict()
# If "encoded_data" is not in the dict, then the data was
# not encoded, but the user may want to just do decoding
# anyway.
checkpoint = checkpoint.from_dict(
cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict))
)
return checkpoint

@Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def encode_data(data_dict: Dict) -> EncodedData:
"""Logic to encode a data dict before sending to the driver.
@@ -56,6 +114,7 @@ def encode_data(data_dict: Dict) -> EncodedData:

return data_dict

@Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def decode_data(encoded_data: EncodedData) -> Dict:
"""Logic to decode an encoded data dict.
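As a rough sketch of the compatibility path above: a legacy Backend subclass that still overrides encode_data/decode_data keeps working, but its payload is wrapped under an "encoded_data" key and a DeprecationWarning is emitted (the LegacyBackend class below is hypothetical):

import pickle

from ray.air.checkpoint import Checkpoint
from ray.train.backend import Backend


class LegacyBackend(Backend):
    # Old-style hooks; new code should put this logic in a Checkpoint subclass.
    @staticmethod
    def encode_data(data_dict):
        return pickle.dumps(data_dict)

    @staticmethod
    def decode_data(encoded_data):
        return pickle.loads(encoded_data)


checkpoint = Checkpoint.from_dict({"weights": [1, 2, 3]})
# Wraps the encoded bytes as {"encoded_data": ...} and warns about the deprecation.
encoded = LegacyBackend._encode_data(checkpoint)
# Reads "encoded_data" back out and runs decode_data to restore a dict checkpoint.
restored = LegacyBackend._decode_data(encoded)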
6 changes: 5 additions & 1 deletion python/ray/train/data_parallel_trainer.py
@@ -8,6 +8,7 @@
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
@@ -41,7 +42,10 @@ def __init__(
)

def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint):
checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor
if isinstance(checkpoint.dir_or_data, dict):
checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor
else:
save_preprocessor_to_dir(self.preprocessor, checkpoint.dir_or_data)
super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint(
checkpoint=checkpoint
)
9 changes: 4 additions & 5 deletions python/ray/train/examples/horovod/horovod_pytorch_example.py
@@ -9,9 +9,9 @@
from torchvision import datasets, transforms

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.horovod import HorovodTrainer
from ray.train.torch.torch_checkpoint import TorchCheckpoint
import ray.train.torch


@@ -152,12 +152,11 @@ def train_func(config):
model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
)
if save_model_as_dict:
checkpoint_dict = dict(model=model.state_dict())
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
else:
checkpoint_dict = dict(model=model)
checkpoint_dict = Checkpoint.from_dict(checkpoint_dict)
checkpoint = TorchCheckpoint.from_model(model)
results.append(loss)
session.report(dict(loss=loss), checkpoint=checkpoint_dict)
session.report(dict(loss=loss), checkpoint=checkpoint)

# Only used for testing.
return results
11 changes: 5 additions & 6 deletions python/ray/train/examples/pytorch/torch_linear_example.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn
import ray.train as train
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig


@@ -48,8 +48,7 @@ def validate_epoch(dataloader, model, loss_fn):
import copy

model_copy = copy.deepcopy(model)
result = {"model": model_copy.cpu().state_dict(), "loss": loss}
return result
return model_copy.cpu().state_dict(), loss


def train_func(config):
@@ -76,12 +75,12 @@ def train_func(config):
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

results = []

for _ in range(epochs):
train_epoch(train_loader, model, loss_fn, optimizer)
result = validate_epoch(validation_loader, model, loss_fn)
state_dict, loss = validate_epoch(validation_loader, model, loss_fn)
result = dict(loss=loss)
results.append(result)
session.report(result)
session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict))

return results

(Diffs for the remaining changed files in this commit are not shown.)