Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[air/tf] Support TensorflowCheckpoint's saved model/h5 format #28474

Merged

Conversation

xwjiang2010
Copy link
Contributor

@xwjiang2010 xwjiang2010 commented Sep 13, 2022

Signed-off-by: xwjiang2010 [email protected]

Why are these changes needed?

Support saved model and h5 formats.

This PR is dependent on landing #28935 first.

Related issue number

Closes #26933

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@xwjiang2010 xwjiang2010 changed the title [draft] Support TensorflowCheckpoint's support on saved model. [draft] Support TensorflowCheckpoint's saved model. Sep 13, 2022
@xwjiang2010 xwjiang2010 changed the title [draft] Support TensorflowCheckpoint's saved model. Support TensorflowCheckpoint's saved model/h5 format. Sep 29, 2022
Copy link
Contributor

@amogkam amogkam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xwjiang2010, left some comments.

Should we update our docs/examples to use saved model format?

python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_predictor.py Outdated Show resolved Hide resolved
Signed-off-by: xwjiang2010 <[email protected]>
python/ray/train/data_parallel_trainer.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
python/ray/train/tensorflow/tensorflow_checkpoint.py Outdated Show resolved Hide resolved
krfricke pushed a commit that referenced this pull request Oct 3, 2022
Checkpoints attributes are reset when passed to `TensorflowPredictor.from_checkpoint`. 

`TensorflowPredictor.from_checkpoint` calls `TensorflowCheckpoint.from_checkpoint`, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a `TensorflowCheckpoint`.

These changes are needed to merge #28474.
@xwjiang2010
Copy link
Contributor Author

Thanks @xwjiang2010, left some comments.

Should we update our docs/examples to use saved model format?

Do you mean to enhance the examples with saved model as well? Or change it to just saved model?

@xwjiang2010
Copy link
Contributor Author

running another suite of unit tests...

Signed-off-by: xwjiang2010 <[email protected]>
super().__init__(preprocessor)

def __repr__(self):
fn_name = getattr(self.model_definition, "__name__", self.model_definition)
fn_name = getattr(self._model, "__name__", self._model)
fn_name_str = ""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amogkam @bveeramani let me know if this is OK. This is mostly to satisfy the length requirement.
It would be something like tf.keras.Sequential object with some cut-off.

)


def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed when ndarray has multiple elements.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this just be return all([np.array_equal(i1, i2) for i1, i2 in zip(w1, w2)])?

@xwjiang2010
Copy link
Contributor Author

Thanks @xwjiang2010, left some comments.
Should we update our docs/examples to use saved model format?

Do you mean to enhance the examples with saved model as well? Or change it to just saved model?

@amogkam Actually looking at our current examples/docs, I also feel that it is not clear when to use framework specific checkpoint v.s. generic air checkpoint.

Additionally, thinking about MLflow etc, it is not clear what should be done if I want to use MLflow to manage all my checkpoints when I use say TensorflowCheckpoint. Need to think more here...

Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
@xwjiang2010 xwjiang2010 requested a review from a team as a code owner October 4, 2022 18:28
Signed-off-by: xwjiang2010 <[email protected]>
Copy link
Contributor

@richardliaw richardliaw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc lgtm

Copy link
Contributor

@amogkam amogkam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xwjiang2010, lgtm! Just one minor nit

)


def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this just be return all([np.array_equal(i1, i2) for i1, i2 in zip(w1, w2)])?

@richardliaw richardliaw changed the title Support TensorflowCheckpoint's saved model/h5 format. [air/tf] Support TensorflowCheckpoint's saved model/h5 format Oct 5, 2022
@richardliaw richardliaw merged commit 0cc4b65 into ray-project:master Oct 5, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
)

Checkpoints attributes are reset when passed to `TensorflowPredictor.from_checkpoint`.

`TensorflowPredictor.from_checkpoint` calls `TensorflowCheckpoint.from_checkpoint`, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a `TensorflowCheckpoint`.

These changes are needed to merge ray-project#28474.

Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
@xwjiang2010 xwjiang2010 deleted the tf_predictor_from_saved_model branch July 26, 2023 19:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Ray AIR] Model needs to be specified twice for TensorFlowPredictor
4 participants