[AIR] Skip checkpoint cast if checkpoint is same type #28935
@@ -467,6 +467,9 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
>>> model = checkpoint.get_model() # doctest: +SKIP
Linear(in_features=1, out_features=1, bias=True)
"""
if type(other) is cls:
This behavior makes sense to me, but I need another pair of eyes with full checkpoint context. @krfricke and @xwjiang2010?
The main question here is whether we should be less strict by using isinstance(other, cls), which will also return True for inherited classes.
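To make the trade-off in this question concrete, here is a minimal sketch of how the two checks differ. The classes below are stand-ins written for illustration only (they are not Ray's implementations), and `CustomTorchCheckpoint` and both helper functions are hypothetical names.

```python
class Checkpoint:
    """Stand-in base class; not Ray's actual Checkpoint."""


class TorchCheckpoint(Checkpoint):
    """Stand-in framework-specific checkpoint."""


class CustomTorchCheckpoint(TorchCheckpoint):
    """Hypothetical user-defined subclass."""


def is_exact_type(other, cls):
    # Strict check, as in the diff: True only when the classes are identical.
    return type(other) is cls


def is_instance_of(other, cls):
    # Looser check: also True when `other` is an instance of a subclass of `cls`.
    return isinstance(other, cls)


custom = CustomTorchCheckpoint()
print(is_exact_type(custom, TorchCheckpoint))   # False
print(is_instance_of(custom, TorchCheckpoint))  # True
```

With the strict check, a subclass instance would still go through the cast; with `isinstance`, it would be treated as already being the target type.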
Can you explain why it's created with default attributes? Is that not a bug?
It's sort-of a bug. More so that the behavior of
Attributes are restored at the
An alternative implementation would be to call
The problem with this approach is that we'd also restore the checkpoint type.
In any case,
Ah, thanks for the explanation.
Can this be documented as a GitHub issue and/or in the code as a TODO? It would be great to have this context captured so we (a) know how to design a better long-term fix, and (b) can reasonably evaluate short-term fixes (e.g. this PR).
LGTM, quick question before merge
The main question here is whether we should be less strict by using isinstance(other, cls), which will also return True for inherited classes.
Actually, thinking about it some more @bveeramani, do we even need to use from_checkpoint in predictors?
The predictors can use the passed-in checkpoints directly without needing to call from_checkpoint.
@krfricke decision was somewhat arbitrary, but the motivation is to make Currently, if you do
you'd get back a
but if we did
Actually thinking about it some more @bveeramani, do we even need to use from_checkpoint in predictors?
Seems like this is still necessary until #28910 is fixed.
This PR looks good as a short-term fix.
If we want to support But yeah, in any case, we should try to unblock #28474
Checkpoint attributes are reset when passed to `TensorflowPredictor.from_checkpoint`. `TensorflowPredictor.from_checkpoint` calls `TensorflowCheckpoint.from_checkpoint`, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a `TensorflowCheckpoint`. These changes are needed to merge ray-project#28474.
Signed-off-by: Weichen Xu <[email protected]>
Signed-off-by: Balaji Veeramani [email protected]
Why are these changes needed?
Checkpoint attributes are reset when passed to `TensorflowPredictor.from_checkpoint`. `TensorflowPredictor.from_checkpoint` calls `TensorflowCheckpoint.from_checkpoint`, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a `TensorflowCheckpoint`.
These changes are needed to merge #28474.
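The behavior described above can be sketched with stand-in classes. This is only an illustration under stated assumptions: the classes are not Ray's implementations, the `flavor` attribute and `from_checkpoint_without_skip` are hypothetical, and the guarded branch is assumed to return `other` unchanged (the visible hunk shows only the `if` line).

```python
class Checkpoint:
    """Stand-in for a checkpoint; not Ray's actual class."""

    def __init__(self, data):
        self._data = data
        # Hypothetical object attribute that is not part of the serialized data.
        self.flavor = "default"

    @classmethod
    def from_checkpoint_without_skip(cls, other):
        # Always rebuilds from the data, so object attributes fall back to defaults.
        return cls(other._data)

    @classmethod
    def from_checkpoint(cls, other):
        # Assumed short-circuit: return the same object when no cast is needed.
        if type(other) is cls:
            return other
        return cls(other._data)


class TensorflowCheckpoint(Checkpoint):
    """Stand-in framework-specific checkpoint."""


original = TensorflowCheckpoint({"weights": [1.0]})
original.flavor = "saved_model"

rebuilt = TensorflowCheckpoint.from_checkpoint_without_skip(original)
skipped = TensorflowCheckpoint.from_checkpoint(original)

print(rebuilt.flavor)  # default
print(skipped.flavor)  # saved_model
```

In this sketch the skip preserves the attribute simply because the same object is returned; a checkpoint of a different type is still cast as before.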
Related issue number
See #28474 and see #26777
Checks
- Sign off every commit (`git commit -s`) in this PR.
- Run `scripts/format.sh` to lint the changes in this PR.