From 0cc4b653134812585da537e90f53b739654602eb Mon Sep 17 00:00:00 2001 From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Date: Wed, 5 Oct 2022 00:54:15 -0700 Subject: [PATCH] [air/tf] Support TensorflowCheckpoint's saved model/h5 format (#28474) --- .../examples/tfx_tabular_train_to_serve.ipynb | 6 +- python/ray/air/callbacks/keras.py | 5 +- python/ray/air/tests/test_keras_callback.py | 6 +- python/ray/train/BUILD | 8 + python/ray/train/data_parallel_trainer.py | 4 +- .../train/tensorflow/tensorflow_checkpoint.py | 223 +++++++++++++++++- .../train/tensorflow/tensorflow_predictor.py | 44 ++-- .../train/tests/test_tensorflow_checkpoint.py | 172 ++++++++++++++ .../train/tests/test_tensorflow_predictor.py | 66 +++--- .../train/tests/test_tensorflow_trainer.py | 18 +- python/ray/train/torch/torch_checkpoint.py | 4 +- 11 files changed, 474 insertions(+), 82 deletions(-) create mode 100644 python/ray/train/tests/test_tensorflow_checkpoint.py diff --git a/doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb b/doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb index 67d720bbc2b6..dcd28367426d 100644 --- a/doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb +++ b/doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb @@ -681,7 +681,7 @@ "outputs": [], "source": [ "from ray.air import session, Checkpoint\n", - "from ray.train.tensorflow import prepare_dataset_shard\n", + "from ray.train.tensorflow import prepare_dataset_shard, TensorflowCheckpoint\n", "\n", "def train_loop_per_worker():\n", " dataset_shard = session.get_dataset_shard(\"train\")\n", @@ -721,9 +721,7 @@ " # This saves checkpoint in a way that can be used by Ray Serve coherently.\n", " session.report(\n", " {},\n", - " checkpoint=Checkpoint.from_dict(\n", - " dict(epoch=epoch, model=model.get_weights())\n", - " ),\n", + " checkpoint=TensorflowCheckpoint.from_model(model),\n", " )" ] }, diff --git a/python/ray/air/callbacks/keras.py 
b/python/ray/air/callbacks/keras.py index 43919fba6f7f..d49d850d5361 100644 --- a/python/ray/air/callbacks/keras.py +++ b/python/ray/air/callbacks/keras.py @@ -1,11 +1,10 @@ from collections import Counter from typing import Dict, List, Optional, Union -from ray.air.constants import MODEL_KEY from tensorflow.keras.callbacks import Callback as KerasCallback from ray.air import session -from ray.air.checkpoint import Checkpoint +from ray.train.tensorflow import TensorflowCheckpoint from ray.util.annotations import PublicAPI @@ -176,7 +175,7 @@ def _handle(self, logs: Dict, when: str = None): checkpoint = None if freq > 0 and self._counter[when] % freq == 0: - checkpoint = Checkpoint.from_dict({MODEL_KEY: self.model.get_weights()}) + checkpoint = TensorflowCheckpoint.from_model(self.model) if not self._metrics: report_dict = logs diff --git a/python/ray/air/tests/test_keras_callback.py b/python/ray/air/tests/test_keras_callback.py index 41629b05a918..19924fdb5dd6 100644 --- a/python/ray/air/tests/test_keras_callback.py +++ b/python/ray/air/tests/test_keras_callback.py @@ -4,10 +4,10 @@ import ray from ray.air import session from ray.air.callbacks.keras import Callback -from ray.air.constants import MODEL_KEY from ray.train.constants import TRAIN_DATASET_KEY from ray.air.config import ScalingConfig from ray.train.tensorflow import ( + TensorflowCheckpoint, TensorflowTrainer, prepare_dataset_shard, TensorflowPredictor, @@ -79,8 +79,8 @@ def test_keras_callback_e2e(): datasets={TRAIN_DATASET_KEY: get_dataset()}, ) checkpoint = trainer.fit().checkpoint - checkpoint_dict = checkpoint.to_dict() - assert MODEL_KEY in checkpoint_dict + assert isinstance(checkpoint, TensorflowCheckpoint) + assert checkpoint._flavor == TensorflowCheckpoint.Flavor.MODEL_WEIGHTS predictor = TensorflowPredictor.from_checkpoint( checkpoint, model_definition=build_model diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD index 3b66d779fa2b..d5e03ba82b74 100644 --- a/python/ray/train/BUILD 
+++ b/python/ray/train/BUILD @@ -310,6 +310,14 @@ py_test( deps = [":train_lib"] ) +py_test( + name = "test_tensorflow_checkpoint", + size = "small", + srcs = ["tests/test_tensorflow_checkpoint.py"], + tags = ["team:ml", "exclusive"], + deps = [":train_lib"] +) + py_test( name = "test_tensorflow_predictor", size = "small", diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index a69fe967730b..62d62b63d05a 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -486,10 +486,10 @@ def _datasets_repr_(self) -> str: return VBox(content, layout=Layout(width="100%")) -def _load_checkpoint( +def _load_checkpoint_dict( checkpoint: Checkpoint, trainer_name: str ) -> Tuple[Any, Optional["Preprocessor"]]: - """Load a Ray Train Checkpoint. + """Load a Ray Train Checkpoint (dict based). This is a private API. diff --git a/python/ray/train/tensorflow/tensorflow_checkpoint.py b/python/ray/train/tensorflow/tensorflow_checkpoint.py index 2fc10ba4f9c3..a843f33c3f1a 100644 --- a/python/ray/train/tensorflow/tensorflow_checkpoint.py +++ b/python/ray/train/tensorflow/tensorflow_checkpoint.py @@ -1,11 +1,16 @@ -from typing import TYPE_CHECKING, Optional +import os +from typing import TYPE_CHECKING, Callable, Optional, Type, Union +from enum import Enum +from os import path import tensorflow as tf from tensorflow import keras +import warnings +from ray.air._internal.checkpointing import save_preprocessor_to_dir from ray.air.checkpoint import Checkpoint from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY -from ray.train.data_parallel_trainer import _load_checkpoint +from ray.train.data_parallel_trainer import _load_checkpoint_dict from ray.util.annotations import PublicAPI if TYPE_CHECKING: @@ -21,6 +26,25 @@ class TensorflowCheckpoint(Checkpoint): ``TensorflowCheckpoint.from_checkpoint(ckpt)``. 
""" + _SERIALIZED_ATTRS = ("_flavor", "_h5_file_path") + + class Flavor(Enum): + # Various flavors with which TensorflowCheckpoint is generated. + # This is necessary metadata to decide how to load model from a checkpoint. + MODEL_WEIGHTS = 1 + SAVED_MODEL = 2 + H5 = 3 + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._flavor = None + # Will only be set when `self._flavor` is `H5`. + self._h5_file_path = None + @classmethod def from_model( cls, @@ -31,8 +55,11 @@ def from_model( """Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Keras model. + The checkpoint created with this method needs to be paired with + `model_definition` when used. + Args: - model: The Keras model to store in the checkpoint. + model: The Keras model, whose weights are stored in the checkpoint. preprocessor: A fitted preprocessor to be applied before inference. Returns: @@ -48,9 +75,191 @@ def from_model( checkpoint = cls.from_dict( {PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()} ) + checkpoint._flavor = cls.Flavor.MODEL_WEIGHTS return checkpoint - def get_model_weights(self) -> tf.keras.Model: - """Retrieve the model weights stored in this checkpoint.""" - model_weights, _ = _load_checkpoint(self, "TensorflowTrainer") - return model_weights + @classmethod + def from_h5( + cls, file_path: str, *, preprocessor: Optional["Preprocessor"] = None + ) -> "TensorflowCheckpoint": + """Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Keras + model from H5 format. + + The checkpoint generated by this method contains all the information needed. + Thus no `model_definition` is needed to be supplied when using this checkpoint. + + `file_path` must maintain validity even after this function returns. + Some new files/directories may be added to the parent directory of `file_path`, + as a side effect of this method. + + Args: + file_path: The path to the .h5 file to load model from. 
This is the + same path that is used for ``model.save(path)``. + preprocessor: A fitted preprocessor to be applied before inference. + + Returns: + A :py:class:`TensorflowCheckpoint` converted from h5 format. + + Examples: + + >>> import tensorflow as tf + + >>> import ray + >>> from ray.train.batch_predictor import BatchPredictor + >>> from ray.train.tensorflow import ( + ... TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor + ... ) + >>> from ray.air import session + >>> from ray.air.config import ScalingConfig + + >>> def train_func(): + ... model = tf.keras.Sequential( + ... [ + ... tf.keras.layers.InputLayer(input_shape=()), + ... tf.keras.layers.Flatten(), + ... tf.keras.layers.Dense(10), + ... tf.keras.layers.Dense(1), + ... ] + ... ) + ... model.save("my_model.h5") + ... checkpoint = TensorflowCheckpoint.from_h5("my_model.h5") + ... session.report({"my_metric": 1}, checkpoint=checkpoint) + + >>> trainer = TensorflowTrainer( + ... train_loop_per_worker=train_func, + ... scaling_config=ScalingConfig(num_workers=2)) + + >>> result_checkpoint = trainer.fit().checkpoint # doctest: +SKIP + + >>> batch_predictor = BatchPredictor.from_checkpoint( + ... result_checkpoint, TensorflowPredictor) # doctest: +SKIP + >>> batch_predictor.predict(ray.data.range(3)) # doctest: +SKIP + """ + if not path.isfile(file_path) or not file_path.endswith(".h5"): + raise ValueError( + "Please supply a h5 file path to `TensorflowCheckpoint.from_h5()`." 
+ ) + dir_path = path.dirname(os.path.abspath(file_path)) + if preprocessor: + save_preprocessor_to_dir(preprocessor, dir_path) + checkpoint = cls.from_directory(dir_path) + checkpoint._flavor = cls.Flavor.H5 + checkpoint._h5_file_path = os.path.basename(file_path) + return checkpoint + + @classmethod + def from_saved_model( + cls, dir_path: str, *, preprocessor: Optional["Preprocessor"] = None + ) -> "TensorflowCheckpoint": + """Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Keras + model from SavedModel format. + + The checkpoint generated by this method contains all the information needed. + Thus no `model_definition` is needed to be supplied when using this checkpoint. + + `dir_path` must maintain validity even after this function returns. + Some new files/directories may be added to `dir_path`, as a side effect + of this method. + + Args: + dir_path: The directory containing the saved model. This is the same + directory as used by ``model.save(dir_path)``. + preprocessor: A fitted preprocessor to be applied before inference. + + Returns: + A :py:class:`TensorflowCheckpoint` converted from SavedModel format. + + Examples: + + >>> import tensorflow as tf + + >>> import ray + >>> from ray.train.batch_predictor import BatchPredictor + >>> from ray.train.tensorflow import ( + ... TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor) + >>> from ray.air import session + >>> from ray.air.config import ScalingConfig + + >>> def train_fn(): + ... model = tf.keras.Sequential( + ... [ + ... tf.keras.layers.InputLayer(input_shape=()), + ... tf.keras.layers.Flatten(), + ... tf.keras.layers.Dense(10), + ... tf.keras.layers.Dense(1), + ... ]) + ... model.save("my_model") + ... checkpoint = TensorflowCheckpoint.from_saved_model("my_model") + ... session.report({"my_metric": 1}, checkpoint=checkpoint) + + >>> trainer = TensorflowTrainer( + ... train_loop_per_worker=train_fn, + ... 
scaling_config=ScalingConfig(num_workers=2)) + + >>> result_checkpoint = trainer.fit().checkpoint # doctest: +SKIP + + >>> batch_predictor = BatchPredictor.from_checkpoint( + ... result_checkpoint, TensorflowPredictor) # doctest: +SKIP + >>> batch_predictor.predict(ray.data.range(3)) # doctest: +SKIP + """ + if preprocessor: + save_preprocessor_to_dir(preprocessor, dir_path) + checkpoint = cls.from_directory(dir_path) + checkpoint._flavor = cls.Flavor.SAVED_MODEL + return checkpoint + + def get_model( + self, + model_definition: Optional[ + Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]] + ] = None, + ) -> tf.keras.Model: + """Retrieve the model stored in this checkpoint. + + Args: + model_definition: This arg is expected only if the original checkpoint + was created via `TensorflowCheckpoint.from_model`. + + Returns: + The Tensorflow Keras model stored in the checkpoint. + """ + if self._flavor is self.Flavor.MODEL_WEIGHTS: + if not model_definition: + raise ValueError( + "Expecting `model_definition` argument when checkpoint is " + "saved through `TensorflowCheckpoint.from_model()`." + ) + model_weights, _ = _load_checkpoint_dict(self, "TensorflowTrainer") + model = model_definition() + model.set_weights(model_weights) + return model + else: + if model_definition: + warnings.warn( + "TensorflowCheckpoint was created from " + "`TensorflowCheckpoint.from_saved_model` or " + "`TensorflowCheckpoint.from_h5`, which already contains all the " + "information needed. This means: " + "If you are using BatchPredictor, you should do " + "`BatchPredictor.from_checkpoint(checkpoint, TensorflowPredictor)`" + " by removing kwargs `model_definition=`. " + "If you are using TensorflowPredictor directly, you should do " + "`TensorflowPredictor.from_checkpoint(checkpoint)` by " + "removing kwargs `model_definition=`." 
+ ) + with self.as_directory() as checkpoint_dir: + if self._flavor == self.Flavor.H5: + return keras.models.load_model( + os.path.join(checkpoint_dir, self._h5_file_path) + ) + elif self._flavor == self.Flavor.SAVED_MODEL: + return keras.models.load_model(checkpoint_dir) + else: + raise RuntimeError( + "Avoid using `from_dict` or " + "`from_directory` directly. Make sure " + "that the checkpoint was generated by " + "`TensorflowCheckpoint.from_model`, " + "`TensorflowCheckpoint.from_saved_model` or " + "`TensorflowCheckpoint.from_h5`." + ) diff --git a/python/ray/train/tensorflow/tensorflow_predictor.py b/python/ray/train/tensorflow/tensorflow_predictor.py index 67e5950794be..6367b4694799 100644 --- a/python/ray/train/tensorflow/tensorflow_predictor.py +++ b/python/ray/train/tensorflow/tensorflow_predictor.py @@ -23,8 +23,7 @@ class TensorflowPredictor(DLPredictor): """A predictor for TensorFlow models. Args: - model_definition: A callable that returns a TensorFlow Keras model - to use for predictions. + model: A Tensorflow Keras model to use for predictions. preprocessor: A preprocessor used to transform data batches prior to prediction. model_weights: List of weights to use for the model. 
@@ -34,14 +33,11 @@ class TensorflowPredictor(DLPredictor): def __init__( self, - model_definition: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]], + *, + model: Optional[tf.keras.Model] = None, preprocessor: Optional["Preprocessor"] = None, - model_weights: Optional[list] = None, use_gpu: bool = False, ): - self.model_definition = model_definition - self.model_weights = model_weights - self.use_gpu = use_gpu # TensorFlow model objects cannot be pickled, therefore we use # a callable that returns the model and initialize it here, @@ -52,9 +48,9 @@ def __init__( if use_gpu: # TODO (jiaodong): #26249 Use multiple GPU devices with sharded input with tf.device("GPU:0"): - self._model = self.model_definition() + self._model = model else: - self._model = self.model_definition() + self._model = model gpu_devices = tf.config.list_physical_devices("GPU") if len(gpu_devices) > 0 and log_once("tf_predictor_not_using_gpu"): logger.warning( @@ -65,18 +61,17 @@ def __init__( "`batch_predictor.predict(ds, num_gpus_per_worker=1)` to " "enable GPU prediction." 
) - - if model_weights is not None: - self._model.set_weights(model_weights) super().__init__(preprocessor) def __repr__(self): - fn_name = getattr(self.model_definition, "__name__", self.model_definition) + fn_name = getattr(self._model, "__name__", self._model) + fn_name_str = "" + if fn_name: + fn_name_str = str(fn_name)[:40] return ( f"{self.__class__.__name__}(" - f"model_definition={fn_name}, " + f"model={fn_name_str!r}, " f"preprocessor={self._preprocessor!r}, " - f"model_weights={self.model_weights!r}, " f"use_gpu={self.use_gpu!r})" ) @@ -84,8 +79,10 @@ def __repr__(self): def from_checkpoint( cls, checkpoint: Checkpoint, - model_definition: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]], - use_gpu: bool = False, + model_definition: Optional[ + Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]] + ] = None, + use_gpu: Optional[bool] = False, ) -> "TensorflowPredictor": """Instantiate the predictor from a Checkpoint. @@ -97,13 +94,15 @@ def from_checkpoint( ``TensorflowTrainer`` run. model_definition: A callable that returns a TensorFlow Keras model to use. Model weights will be loaded from the checkpoint. + This is only needed if the `checkpoint` was created from + `TensorflowCheckpoint.from_model`. + use_gpu: Whether GPU should be used during prediction. """ checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) - model_weights = checkpoint.get_model_weights() + model = checkpoint.get_model(model_definition) preprocessor = checkpoint.get_preprocessor() return cls( - model_definition=model_definition, - model_weights=model_weights, + model=model, preprocessor=preprocessor, use_gpu=use_gpu, ) @@ -192,8 +191,7 @@ def predict( ... ) >>> >>> weights = [np.array([[2.0]]), np.array([0.0])] - >>> predictor = TensorflowPredictor( - ... 
model_definition=build_model, model_weights=weights) + >>> predictor = TensorflowPredictor(model=build_model()) >>> >>> data = np.asarray([1, 2, 3]) >>> predictions = predictor.predict(data) # doctest: +SKIP @@ -210,7 +208,7 @@ def predict( ... return tf.keras.models.Model( ... inputs=[input1, input2], outputs=output) >>> - >>> predictor = TensorflowPredictor(model_definition=build_model) + >>> predictor = TensorflowPredictor(model=build_model()) >>> >>> # Pandas dataframe. >>> data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) diff --git a/python/ray/train/tests/test_tensorflow_checkpoint.py b/python/ray/train/tests/test_tensorflow_checkpoint.py new file mode 100644 index 000000000000..8b21b5fbcd62 --- /dev/null +++ b/python/ray/train/tests/test_tensorflow_checkpoint.py @@ -0,0 +1,172 @@ +from numpy import ndarray +import os.path +import pytest +import tempfile +import tensorflow as tf +from typing import List +import unittest + +import ray +from ray.train.batch_predictor import BatchPredictor +from ray.train.tensorflow import ( + TensorflowCheckpoint, + TensorflowTrainer, + TensorflowPredictor, +) +from ray.air import session +from ray.air.config import ScalingConfig +from ray.data import Preprocessor + + +class DummyPreprocessor(Preprocessor): + def __init__(self, multiplier): + self.multiplier = multiplier + + def transform_batch(self, df): + return df * self.multiplier + + +def get_model(): + return tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=()), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10), + tf.keras.layers.Dense(1), + ] + ) + + +def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool: + if not len(w1) == len(w2): + return False + size = len(w1) + for i in range(size): + comparison = w1[i] == w2[i] + if not comparison.all(): + return False + + return True + + +class TestFromModel(unittest.TestCase): + def setUp(self): + self.model = get_model() + self.preprocessor = DummyPreprocessor(1) + + def 
test_from_model(self): + checkpoint = TensorflowCheckpoint.from_model( + self.model, preprocessor=DummyPreprocessor(1) + ) + loaded_model = checkpoint.get_model(model_definition=get_model) + preprocessor = checkpoint.get_preprocessor() + + assert compare_weights(loaded_model.get_weights(), self.model.get_weights()) + assert preprocessor.multiplier == 1 + + def test_from_model_no_definition(self): + checkpoint = TensorflowCheckpoint.from_model( + self.model, preprocessor=self.preprocessor + ) + with self.assertRaises(ValueError): + checkpoint.get_model() + + def test_from_saved_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + model_dir_path = os.path.join(tmp_dir, "my_model") + self.model.save(model_dir_path) + checkpoint = TensorflowCheckpoint.from_saved_model( + model_dir_path, preprocessor=DummyPreprocessor(1) + ) + loaded_model = checkpoint.get_model() + preprocessor = checkpoint.get_preprocessor() + assert compare_weights(self.model.get_weights(), loaded_model.get_weights()) + assert preprocessor.multiplier == 1 + + def test_from_saved_model_warning_with_model_definition(self): + with tempfile.TemporaryDirectory() as tmp_dir: + model_dir_path = os.path.join(tmp_dir, "my_model") + self.model.save(model_dir_path) + checkpoint = TensorflowCheckpoint.from_saved_model( + model_dir_path, + preprocessor=DummyPreprocessor(1), + ) + with pytest.warns(None): + loaded_model = checkpoint.get_model(model_definition=get_model) + preprocessor = checkpoint.get_preprocessor() + assert compare_weights(self.model.get_weights(), loaded_model.get_weights()) + assert preprocessor.multiplier == 1 + + def test_from_h5_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + model_file_path = os.path.join(tmp_dir, "my_model.h5") + self.model.save(model_file_path) + checkpoint = TensorflowCheckpoint.from_h5( + model_file_path, preprocessor=DummyPreprocessor(1) + ) + loaded_model = checkpoint.get_model() + preprocessor = checkpoint.get_preprocessor() + assert 
compare_weights(self.model.get_weights(), loaded_model.get_weights()) + assert preprocessor.multiplier == 1 + + +def test_tensorflow_checkpoint_saved_model(): + # The test passes if the following can run successfully. + + def train_fn(): + model = tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=()), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10), + tf.keras.layers.Dense(1), + ] + ) + model.save("my_model") + checkpoint = TensorflowCheckpoint.from_saved_model("my_model") + session.report({"my_metric": 1}, checkpoint=checkpoint) + + trainer = TensorflowTrainer( + train_loop_per_worker=train_fn, scaling_config=ScalingConfig(num_workers=2) + ) + + result_checkpoint = trainer.fit().checkpoint + + batch_predictor = BatchPredictor.from_checkpoint( + result_checkpoint, TensorflowPredictor + ) + batch_predictor.predict(ray.data.range(3)) + + +def test_tensorflow_checkpoint_h5(): + # The test passes if the following can run successfully. + + def train_func(): + model = tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=()), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10), + tf.keras.layers.Dense(1), + ] + ) + model.save("my_model.h5") + checkpoint = TensorflowCheckpoint.from_h5("my_model.h5") + session.report({"my_metric": 1}, checkpoint=checkpoint) + + trainer = TensorflowTrainer( + train_loop_per_worker=train_func, scaling_config=ScalingConfig(num_workers=2) + ) + + result_checkpoint = trainer.fit().checkpoint + + batch_predictor = BatchPredictor.from_checkpoint( + result_checkpoint, TensorflowPredictor + ) + batch_predictor.predict(ray.data.range(3)) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/train/tests/test_tensorflow_predictor.py b/python/ray/train/tests/test_tensorflow_predictor.py index 375dd160422b..d87346b60e95 100644 --- a/python/ray/train/tests/test_tensorflow_predictor.py +++ b/python/ray/train/tests/test_tensorflow_predictor.py @@ -7,7 
+7,7 @@ import ray from ray.air.checkpoint import Checkpoint -from ray.air.constants import MAX_REPR_LENGTH, MODEL_KEY, PREPROCESSOR_KEY +from ray.air.constants import MAX_REPR_LENGTH from ray.air.util.data_batch_conversion import ( convert_pandas_to_batch_type, convert_batch_type_to_pandas, @@ -21,7 +21,7 @@ from dummy_preprocessor import DummyPreprocessor -def build_model() -> tf.keras.Model: +def build_raw_model() -> tf.keras.Model: model = tf.keras.Sequential( [ tf.keras.layers.InputLayer(input_shape=()), @@ -33,6 +33,15 @@ def build_model() -> tf.keras.Model: return model +weights = [np.array([[2.0]]), np.array([0.0])] + + +def build_model() -> tf.keras.Model: + model = build_raw_model() + model.set_weights(weights) + return model + + def build_model_multi_input() -> tf.keras.Model: input1 = tf.keras.layers.Input(shape=(1,), name="A") input2 = tf.keras.layers.Input(shape=(1,), name="B") @@ -54,11 +63,8 @@ def build_model_unsupported() -> tf.keras.Model: return model -weights = [np.array([[2.0]]), np.array([0.0])] - - def test_repr(): - predictor = TensorflowPredictor(model_definition=build_model) + predictor = TensorflowPredictor(model=build_model()) representation = repr(predictor) @@ -69,8 +75,8 @@ def test_repr(): def create_checkpoint_preprocessor() -> Tuple[Checkpoint, Preprocessor]: preprocessor = DummyPreprocessor() - checkpoint = Checkpoint.from_dict( - {MODEL_KEY: weights, PREPROCESSOR_KEY: preprocessor} + checkpoint = TensorflowCheckpoint.from_model( + build_model(), preprocessor=preprocessor ) return checkpoint, preprocessor @@ -79,14 +85,13 @@ def create_checkpoint_preprocessor() -> Tuple[Checkpoint, Preprocessor]: def test_init(): checkpoint, preprocessor = create_checkpoint_preprocessor() - predictor = TensorflowPredictor( - model_definition=build_model, preprocessor=preprocessor, model_weights=weights - ) + predictor = TensorflowPredictor(model=build_model(), preprocessor=preprocessor) - checkpoint_predictor = 
TensorflowPredictor.from_checkpoint(checkpoint, build_model) + checkpoint_predictor = TensorflowPredictor.from_checkpoint( + checkpoint, model_definition=build_raw_model + ) - assert checkpoint_predictor.model_definition == predictor.model_definition - assert checkpoint_predictor.model_weights == predictor.model_weights + assert checkpoint_predictor._model.get_weights() == predictor._model.get_weights() assert checkpoint_predictor.get_preprocessor() == predictor.get_preprocessor() @@ -96,20 +101,24 @@ def test_tensorflow_checkpoint(): preprocessor = DummyPreprocessor() checkpoint = TensorflowCheckpoint.from_model(model, preprocessor=preprocessor) - assert checkpoint.get_model_weights() == model.get_weights() + assert ( + checkpoint.get_model(model_definition=build_raw_model).get_weights() + == model.get_weights() + ) with checkpoint.as_directory() as path: checkpoint = TensorflowCheckpoint.from_directory(path) checkpoint_preprocessor = checkpoint.get_preprocessor() - assert checkpoint.get_model_weights() == model.get_weights() + assert ( + checkpoint.get_model(model_definition=build_raw_model).get_weights() + == model.get_weights() + ) assert checkpoint_preprocessor == preprocessor @pytest.mark.parametrize("use_gpu", [False, True]) def test_predict_array(use_gpu): - predictor = TensorflowPredictor( - model_definition=build_model, model_weights=weights, use_gpu=use_gpu - ) + predictor = TensorflowPredictor(model=build_model(), use_gpu=use_gpu) data_batch = np.asarray([1, 2, 3]) predictions = predictor.predict(data_batch) @@ -122,9 +131,8 @@ def test_predict_array(use_gpu): def test_predict_array_with_preprocessor(use_gpu): preprocessor = DummyPreprocessor() predictor = TensorflowPredictor( - model_definition=build_model, + model=build_model(), preprocessor=preprocessor, - model_weights=weights, use_gpu=use_gpu, ) @@ -138,7 +146,7 @@ def test_predict_array_with_preprocessor(use_gpu): @pytest.mark.parametrize("batch_type", [np.ndarray, pd.DataFrame, pa.Table, dict]) 
def test_predict(batch_type): - predictor = TensorflowPredictor(model_definition=build_model_multi_input) + predictor = TensorflowPredictor(model=build_model_multi_input()) raw_batch = pd.DataFrame({"A": [0.0, 0.0, 0.0], "B": [1.0, 2.0, 3.0]}) data_batch = convert_pandas_to_batch_type(raw_batch, type=TYPE_TO_ENUM[batch_type]) @@ -151,7 +159,7 @@ def test_predict(batch_type): @pytest.mark.parametrize("batch_type", [pd.DataFrame, pa.Table]) def test_predict_batch(ray_start_4_cpus, batch_type): - checkpoint = TensorflowCheckpoint.from_dict({MODEL_KEY: {}}) + checkpoint = TensorflowCheckpoint.from_model(model=build_model_multi_input()) predictor = BatchPredictor.from_checkpoint( checkpoint, TensorflowPredictor, model_definition=build_model_multi_input ) @@ -176,9 +184,7 @@ def test_predict_batch(ray_start_4_cpus, batch_type): @pytest.mark.parametrize("use_gpu", [False, True]) def test_predict_dataframe(use_gpu): - predictor = TensorflowPredictor( - model_definition=build_model_multi_input, use_gpu=use_gpu - ) + predictor = TensorflowPredictor(model=build_model_multi_input(), use_gpu=use_gpu) data_batch = pd.DataFrame({"A": [0.0, 0.0, 0.0], "B": [1.0, 2.0, 3.0]}) predictions = predictor.predict(data_batch) @@ -189,9 +195,7 @@ def test_predict_dataframe(use_gpu): @pytest.mark.parametrize("use_gpu", [False, True]) def test_predict_multi_output(use_gpu): - predictor = TensorflowPredictor( - model_definition=build_model_multi_output, use_gpu=use_gpu - ) + predictor = TensorflowPredictor(model=build_model_multi_output(), use_gpu=use_gpu) data_batch = np.array([1, 2, 3]) predictions = predictor.predict(data_batch) @@ -206,7 +210,7 @@ def test_predict_multi_output(use_gpu): def test_predict_unsupported_output(): """Tests predictions with models that have unsupported output types.""" - predictor = TensorflowPredictor(model_definition=build_model_unsupported) + predictor = TensorflowPredictor(model=build_model_unsupported()) data_batch = np.array([1, 2, 3]) # Unsupported output 
should fail @@ -219,7 +223,7 @@ def call_model(self, tensor): model_output = super().call_model(tensor) return {str(i): model_output[i] for i in range(len(model_output))} - predictor = CustomPredictor(model_definition=build_model_unsupported) + predictor = CustomPredictor(model=build_model_unsupported()) predictions = predictor.predict(data_batch) # Model outputs two tensors diff --git a/python/ray/train/tests/test_tensorflow_trainer.py b/python/ray/train/tests/test_tensorflow_trainer.py index 55dbd91a74d9..50a3d8cb749e 100644 --- a/python/ray/train/tests/test_tensorflow_trainer.py +++ b/python/ray/train/tests/test_tensorflow_trainer.py @@ -5,14 +5,17 @@ import ray from ray.air import session -from ray.air.checkpoint import Checkpoint from ray.air.examples.tf.tensorflow_regression_example import ( get_dataset, train_func as tensorflow_linear_train_func, ) from ray.air.config import ScalingConfig -from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY -from ray.train.tensorflow import TensorflowPredictor, TensorflowTrainer +from ray.train.constants import TRAIN_DATASET_KEY +from ray.train.tensorflow import ( + TensorflowCheckpoint, + TensorflowPredictor, + TensorflowTrainer, +) @pytest.fixture @@ -66,8 +69,8 @@ def train_func(config): def test_tensorflow_e2e(ray_start_4_cpus): def train_func(): - model = build_model().get_weights() - session.report({}, checkpoint=Checkpoint.from_dict({MODEL_KEY: model})) + model = build_model() + session.report({}, checkpoint=TensorflowCheckpoint.from_model(model)) scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( @@ -101,9 +104,10 @@ def train_func(): else: model = build_model() - model.save("my_model", overwrite=True) + model.save("my_model") session.report( - metrics={"iter": 1}, checkpoint=Checkpoint.from_directory("my_model") + metrics={"iter": 1}, + checkpoint=TensorflowCheckpoint.from_saved_model("my_model"), ) scaling_config = ScalingConfig(num_workers=2) diff --git 
a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 3df889f709b6..774256965c8b 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -5,7 +5,7 @@ from ray.air.checkpoint import Checkpoint from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY -from ray.train.data_parallel_trainer import _load_checkpoint +from ray.train.data_parallel_trainer import _load_checkpoint_dict from ray.air._internal.torch_utils import load_torch_model from ray.util.annotations import PublicAPI @@ -106,7 +106,7 @@ def get_model(self, model: Optional[torch.nn.Module] = None) -> torch.nn.Module: the model itself, then the state dict will be loaded to this ``model``. Otherwise, the model will be discarded. """ - saved_model, _ = _load_checkpoint(self, "TorchTrainer") + saved_model, _ = _load_checkpoint_dict(self, "TorchTrainer") if isinstance(saved_model, torch.nn.Module): if model: