[air/tf] Support TensorflowCheckpoint's saved model/h5 format (#28474)
xwjiang2010 authored Oct 5, 2022
1 parent d80c589 commit 0cc4b65
Showing 11 changed files with 474 additions and 82 deletions.
6 changes: 2 additions & 4 deletions doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb
@@ -681,7 +681,7 @@
"outputs": [],
"source": [
"from ray.air import session, Checkpoint\n",
"from ray.train.tensorflow import prepare_dataset_shard\n",
"from ray.train.tensorflow import prepare_dataset_shard, TensorflowCheckpoint\n",
"\n",
"def train_loop_per_worker():\n",
" dataset_shard = session.get_dataset_shard(\"train\")\n",
@@ -721,9 +721,7 @@
" # This saves the checkpoint in a way that can be used by Ray Serve coherently.\n",
" session.report(\n",
" {},\n",
" checkpoint=Checkpoint.from_dict(\n",
" dict(epoch=epoch, model=model.get_weights())\n",
" ),\n",
" checkpoint=TensorflowCheckpoint.from_model(model),\n",
" )"
]
},
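For readers skimming the notebook change: a minimal sketch of the new reporting pattern, assuming a hypothetical `build_model()` helper and a toy training loop (neither is part of this diff). The train loop now reports a `TensorflowCheckpoint` built from the model itself rather than a hand-rolled weights dict:

    import tensorflow as tf
    from ray.air import session
    from ray.train.tensorflow import TensorflowCheckpoint

    def build_model() -> tf.keras.Model:
        # Hypothetical stand-in for the notebook's model definition.
        return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

    def train_loop_per_worker():
        model = build_model()
        for epoch in range(3):
            # ... one epoch of training elided ...
            # Report a checkpoint that Ray Serve can consume coherently.
            session.report(
                {"epoch": epoch},
                checkpoint=TensorflowCheckpoint.from_model(model),
            )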
5 changes: 2 additions & 3 deletions python/ray/air/callbacks/keras.py
@@ -1,11 +1,10 @@
from collections import Counter
from typing import Dict, List, Optional, Union

from ray.air.constants import MODEL_KEY
from tensorflow.keras.callbacks import Callback as KerasCallback

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.train.tensorflow import TensorflowCheckpoint
from ray.util.annotations import PublicAPI


@@ -176,7 +175,7 @@ def _handle(self, logs: Dict, when: str = None):

checkpoint = None
if freq > 0 and self._counter[when] % freq == 0:
checkpoint = Checkpoint.from_dict({MODEL_KEY: self.model.get_weights()})
checkpoint = TensorflowCheckpoint.from_model(self.model)

if not self._metrics:
report_dict = logs
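In practice, the callback change means checkpoints written during Keras training are now MODEL_WEIGHTS-flavored `TensorflowCheckpoint`s (the test below asserts exactly this). A short sketch, where the `build_model` helper is illustrative rather than taken from this diff:

    import tensorflow as tf
    from ray.train.tensorflow import TensorflowCheckpoint

    def build_model() -> tf.keras.Model:
        # Illustrative model definition.
        return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

    model = build_model()
    # The callback now wraps the live Keras model instead of a raw weights dict.
    checkpoint = TensorflowCheckpoint.from_model(model)
    # Restoring later still requires the matching model definition.
    restored = checkpoint.get_model(model_definition=build_model)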
6 changes: 3 additions & 3 deletions python/ray/air/tests/test_keras_callback.py
@@ -4,10 +4,10 @@
import ray
from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.air.constants import MODEL_KEY
from ray.train.constants import TRAIN_DATASET_KEY
from ray.air.config import ScalingConfig
from ray.train.tensorflow import (
TensorflowCheckpoint,
TensorflowTrainer,
prepare_dataset_shard,
TensorflowPredictor,
@@ -79,8 +79,8 @@ def test_keras_callback_e2e():
datasets={TRAIN_DATASET_KEY: get_dataset()},
)
checkpoint = trainer.fit().checkpoint
checkpoint_dict = checkpoint.to_dict()
assert MODEL_KEY in checkpoint_dict
assert isinstance(checkpoint, TensorflowCheckpoint)
assert checkpoint._flavor == TensorflowCheckpoint.Flavor.MODEL_WEIGHTS

predictor = TensorflowPredictor.from_checkpoint(
checkpoint, model_definition=build_model
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -310,6 +310,14 @@ py_test(
deps = [":train_lib"]
)

py_test(
name = "test_tensorflow_checkpoint",
size = "small",
srcs = ["tests/test_tensorflow_checkpoint.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib"]
)

py_test(
name = "test_tensorflow_predictor",
size = "small",
4 changes: 2 additions & 2 deletions python/ray/train/data_parallel_trainer.py
@@ -486,10 +486,10 @@ def _datasets_repr_(self) -> str:
return VBox(content, layout=Layout(width="100%"))


def _load_checkpoint(
def _load_checkpoint_dict(
checkpoint: Checkpoint, trainer_name: str
) -> Tuple[Any, Optional["Preprocessor"]]:
"""Load a Ray Train Checkpoint.
"""Load a Ray Train Checkpoint (dict based).
This is a private API.
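A hedged sketch of how the renamed helper is consumed (it mirrors the call added in `tensorflow_checkpoint.py` below; `_load_checkpoint_dict` is a private API, so treat this as illustrative only):

    import tensorflow as tf
    from ray.train.data_parallel_trainer import _load_checkpoint_dict
    from ray.train.tensorflow import TensorflowCheckpoint

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    checkpoint = TensorflowCheckpoint.from_model(model)
    # Returns the stored payload (here, model weights) and any fitted preprocessor.
    model_weights, preprocessor = _load_checkpoint_dict(checkpoint, "TensorflowTrainer")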
223 changes: 216 additions & 7 deletions python/ray/train/tensorflow/tensorflow_checkpoint.py
@@ -1,11 +1,16 @@
from typing import TYPE_CHECKING, Optional
import os
from typing import TYPE_CHECKING, Callable, Optional, Type, Union

from enum import Enum
from os import path
import tensorflow as tf
from tensorflow import keras
import warnings

from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.train.data_parallel_trainer import _load_checkpoint_dict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
@@ -21,6 +26,25 @@ class TensorflowCheckpoint(Checkpoint):
``TensorflowCheckpoint.from_checkpoint(ckpt)``.
"""

_SERIALIZED_ATTRS = ("_flavor", "_h5_file_path")

class Flavor(Enum):
# Various flavors with which a TensorflowCheckpoint can be generated.
# This metadata is needed to decide how to load the model from a checkpoint.
MODEL_WEIGHTS = 1
SAVED_MODEL = 2
H5 = 3

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._flavor = None
# Will only be set when `self._flavor` is `H5`.
self._h5_file_path = None

@classmethod
def from_model(
cls,
@@ -31,8 +55,11 @@ def from_model(
"""Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Keras
model.
The checkpoint created with this method needs to be paired with
`model_definition` when used.
Args:
model: The Keras model to store in the checkpoint.
model: The Keras model, whose weights are stored in the checkpoint.
preprocessor: A fitted preprocessor to be applied before inference.
Returns:
Expand All @@ -48,9 +75,191 @@ def from_model(
checkpoint = cls.from_dict(
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()}
)
checkpoint._flavor = cls.Flavor.MODEL_WEIGHTS
return checkpoint

def get_model_weights(self) -> tf.keras.Model:
"""Retrieve the model weights stored in this checkpoint."""
model_weights, _ = _load_checkpoint(self, "TensorflowTrainer")
return model_weights
@classmethod
def from_h5(
cls, file_path: str, *, preprocessor: Optional["Preprocessor"] = None
) -> "TensorflowCheckpoint":
"""Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Keras
model from H5 format.
The checkpoint generated by this method contains all the information needed,
so no `model_definition` has to be supplied when using this checkpoint.
`file_path` must remain valid even after this function returns.
Some new files/directories may be added to the parent directory of `file_path`,
as a side effect of this method.
Args:
file_path: The path to the .h5 file to load the model from. This is the
same path that is used for ``model.save(path)``.
preprocessor: A fitted preprocessor to be applied before inference.
Returns:
A :py:class:`TensorflowCheckpoint` converted from h5 format.
Examples:
>>> import tensorflow as tf
>>> import ray
>>> from ray.train.batch_predictor import BatchPredictor
>>> from ray.train.tensorflow import (
... TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor
... )
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> def train_func():
... model = tf.keras.Sequential(
... [
... tf.keras.layers.InputLayer(input_shape=()),
... tf.keras.layers.Flatten(),
... tf.keras.layers.Dense(10),
... tf.keras.layers.Dense(1),
... ]
... )
... model.save("my_model.h5")
... checkpoint = TensorflowCheckpoint.from_h5("my_model.h5")
... session.report({"my_metric": 1}, checkpoint=checkpoint)
>>> trainer = TensorflowTrainer(
... train_loop_per_worker=train_func,
... scaling_config=ScalingConfig(num_workers=2))
>>> result_checkpoint = trainer.fit().checkpoint # doctest: +SKIP
>>> batch_predictor = BatchPredictor.from_checkpoint(
... result_checkpoint, TensorflowPredictor) # doctest: +SKIP
>>> batch_predictor.predict(ray.data.range(3)) # doctest: +SKIP
"""
if not path.isfile(file_path) or not file_path.endswith(".h5"):
raise ValueError(
"Please supply a h5 file path to `TensorflowCheckpoint.from_h5()`."
)
dir_path = path.dirname(os.path.abspath(file_path))
if preprocessor:
save_preprocessor_to_dir(preprocessor, dir_path)
checkpoint = cls.from_directory(dir_path)
checkpoint._flavor = cls.Flavor.H5
checkpoint._h5_file_path = os.path.basename(file_path)
return checkpoint

@classmethod
def from_saved_model(
cls, dir_path: str, *, preprocessor: Optional["Preprocessor"] = None
) -> "TensorflowCheckpoint":
"""Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Keras
model from SavedModel format.
The checkpoint generated by this method contains all the information needed,
so no `model_definition` has to be supplied when using this checkpoint.
`dir_path` must remain valid even after this function returns.
Some new files/directories may be added to `dir_path`, as a side effect
of this method.
Args:
dir_path: The directory containing the saved model. This is the same
directory as used by ``model.save(dir_path)``.
preprocessor: A fitted preprocessor to be applied before inference.
Returns:
A :py:class:`TensorflowCheckpoint` converted from SavedModel format.
Examples:
>>> import tensorflow as tf
>>> import ray
>>> from ray.train.batch_predictor import BatchPredictor
>>> from ray.train.tensorflow import (
... TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor)
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> def train_fn():
... model = tf.keras.Sequential(
... [
... tf.keras.layers.InputLayer(input_shape=()),
... tf.keras.layers.Flatten(),
... tf.keras.layers.Dense(10),
... tf.keras.layers.Dense(1),
... ])
... model.save("my_model")
... checkpoint = TensorflowCheckpoint.from_saved_model("my_model")
... session.report({"my_metric": 1}, checkpoint=checkpoint)
>>> trainer = TensorflowTrainer(
... train_loop_per_worker=train_fn,
... scaling_config=ScalingConfig(num_workers=2))
>>> result_checkpoint = trainer.fit().checkpoint # doctest: +SKIP
>>> batch_predictor = BatchPredictor.from_checkpoint(
... result_checkpoint, TensorflowPredictor) # doctest: +SKIP
>>> batch_predictor.predict(ray.data.range(3)) # doctest: +SKIP
"""
if preprocessor:
save_preprocessor_to_dir(preprocessor, dir_path)
checkpoint = cls.from_directory(dir_path)
checkpoint._flavor = cls.Flavor.SAVED_MODEL
return checkpoint

def get_model(
self,
model_definition: Optional[
Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]]
] = None,
) -> tf.keras.Model:
"""Retrieve the model stored in this checkpoint.
Args:
model_definition: This argument is needed only if the original checkpoint
was created via `TensorflowCheckpoint.from_model`.
Returns:
The Tensorflow Keras model stored in the checkpoint.
"""
if self._flavor is self.Flavor.MODEL_WEIGHTS:
if not model_definition:
raise ValueError(
"Expecting `model_definition` argument when checkpoint is "
"saved through `TensorflowCheckpoint.from_model()`."
)
model_weights, _ = _load_checkpoint_dict(self, "TensorflowTrainer")
model = model_definition()
model.set_weights(model_weights)
return model
else:
if model_definition:
warnings.warn(
"TensorflowCheckpoint was created from "
"TensorflowCheckpoint.from_saved_model` or "
"`TensorflowCheckpoint.from_h5`, which already contains all the "
"information needed. This means: "
"If you are using BatchPredictor, you should do "
"`BatchPredictor.from_checkpoint(checkpoint, TensorflowPredictor)`"
" by removing kwargs `model_definition=`. "
"If you are using TensorflowPredictor directly, you should do "
"`TensorflowPredictor.from_checkpoint(checkpoint)` by "
"removing kwargs `model_definition=`."
)
with self.as_directory() as checkpoint_dir:
if self._flavor == self.Flavor.H5:
return keras.models.load_model(
os.path.join(checkpoint_dir, self._h5_file_path)
)
elif self._flavor == self.Flavor.SAVED_MODEL:
return keras.models.load_model(checkpoint_dir)
else:
raise RuntimeError(
"Avoid directly using `from_dict` or "
"`from_directory` directly. Make sure "
"that the checkpoint was generated by "
"`TensorflowCheckpoint.from_model`, "
"`TensorflowCheckpoint.from_saved_model` or "
"`TensorflowCheckpoint.from_h5`."
)
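Taken together, a minimal end-to-end sketch of the three flavors this commit introduces (the file paths and the `build_model` helper are illustrative, not taken from the diff):

    import tensorflow as tf
    from ray.train.tensorflow import TensorflowCheckpoint

    def build_model() -> tf.keras.Model:
        return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

    model = build_model()

    # MODEL_WEIGHTS flavor: only weights are stored, so get_model()
    # needs the matching model definition.
    ckpt = TensorflowCheckpoint.from_model(model)
    restored = ckpt.get_model(model_definition=build_model)

    # H5 flavor: the .h5 file is self-describing; no model_definition needed.
    model.save("my_model.h5")
    ckpt = TensorflowCheckpoint.from_h5("my_model.h5")
    restored = ckpt.get_model()

    # SAVED_MODEL flavor: likewise self-describing.
    model.save("my_model")
    ckpt = TensorflowCheckpoint.from_saved_model("my_model")
    restored = ckpt.get_model()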
