-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[air/tf] Support TensorflowCheckpoint's saved model/h5 format #28474
[air/tf] Support TensorflowCheckpoint's saved model/h5 format #28474
Conversation
Signed-off-by: xwjiang2010 <[email protected]>
…redictor_from_saved_model
Signed-off-by: xwjiang2010 <[email protected]>
38a77e3
to
a3177f1
Compare
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
…redictor_from_saved_model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @xwjiang2010, left some comments.
Should we update our docs/examples to use saved model format?
…redictor_from_saved_model
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
e8523f3
to
b8db76e
Compare
Signed-off-by: xwjiang2010 <[email protected]>
Checkpoints attributes are reset when passed to `TensorflowPredictor.from_checkpoint`. `TensorflowPredictor.from_checkpoint` calls `TensorflowCheckpoint.from_checkpoint`, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a `TensorflowCheckpoint`. These changes are needed to merge #28474.
Signed-off-by: xwjiang2010 <[email protected]>
…redictor_from_saved_model
Signed-off-by: xwjiang2010 <[email protected]>
Do you mean to enhance the examples with saved model as well? Or change it to just saved model? |
running another suite of unit tests... |
Signed-off-by: xwjiang2010 <[email protected]>
super().__init__(preprocessor) | ||
|
||
def __repr__(self): | ||
fn_name = getattr(self.model_definition, "__name__", self.model_definition) | ||
fn_name = getattr(self._model, "__name__", self._model) | ||
fn_name_str = "" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amogkam @bveeramani let me know if this is OK. This is mostly to satisfy the length requirement.
It would be something like tf.keras.Sequential object
with some cut-off.
) | ||
|
||
|
||
def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is needed when ndarray has multiple elements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this just be return all([np.array_equal(i1, i2) for i1, i2 in zip(w1, w2)])
?
@amogkam Actually looking at our current examples/docs, I also feel that it is not clear when to use framework specific checkpoint v.s. generic air checkpoint. Additionally, thinking about MLflow etc, it is not clear what should be done if I want to use MLflow to manage all my checkpoints when I use say TensorflowCheckpoint. Need to think more here... |
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
Signed-off-by: xwjiang2010 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @xwjiang2010, lgtm! Just one minor nit
) | ||
|
||
|
||
def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this just be return all([np.array_equal(i1, i2) for i1, i2 in zip(w1, w2)])
?
) Checkpoints attributes are reset when passed to `TensorflowPredictor.from_checkpoint`. `TensorflowPredictor.from_checkpoint` calls `TensorflowCheckpoint.from_checkpoint`, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a `TensorflowCheckpoint`. These changes are needed to merge ray-project#28474. Signed-off-by: Weichen Xu <[email protected]>
…oject#28474) Signed-off-by: Weichen Xu <[email protected]>
Signed-off-by: xwjiang2010 [email protected]
Why are these changes needed?
Support saved model and h5 formats.
This PR is dependent on landing #28935 first.
Related issue number
Closes #26933
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.