
[AIR] Skip checkpoint cast if checkpoint is same type #28935

Merged

Conversation

@bveeramani (Member) commented Sep 30, 2022:

Signed-off-by: Balaji Veeramani [email protected]

Why are these changes needed?

Checkpoint attributes are reset when a checkpoint is passed to TensorflowPredictor.from_checkpoint.

TensorflowPredictor.from_checkpoint calls TensorflowCheckpoint.from_checkpoint, which creates a new checkpoint with default object attributes -- even if the passed-in checkpoint is already a TensorflowCheckpoint.
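
For illustration, a minimal self-contained sketch of the failure mode using simplified stand-in classes (not Ray's actual implementation; all names are illustrative):

class Checkpoint:
    def __init__(self, data_dict=None):
        self._data_dict = data_dict

    @classmethod
    def from_checkpoint(cls, other):
        # Pre-PR behavior: always re-construct, even when `other` is already a `cls`.
        return cls(data_dict=other._data_dict)

class TensorflowCheckpoint(Checkpoint):
    def __init__(self, data_dict=None):
        super().__init__(data_dict)
        self._flavor = "default"  # stand-in for subclass-specific state

ckpt = TensorflowCheckpoint({"weights": [1, 2, 3]})
ckpt._flavor = "saved_model"  # non-default state set after construction
recast = TensorflowCheckpoint.from_checkpoint(ckpt)
print(recast._flavor)  # "default" -- the non-default attribute was silently reset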

These changes are needed to merge #28474.

Related issue number

See #28474 and #26777.

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@bveeramani changed the title [AIR] Skip checkpoint cast if checkpoint is correct type → [AIR] Skip checkpoint cast if checkpoint is same type on Sep 30, 2022
@@ -467,6 +467,9 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
>>> model = checkpoint.get_model() # doctest: +SKIP
Linear(in_features=1, out_features=1, bias=True)
"""
if type(other) is cls:
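
The inline diff is truncated in this capture; based on the PR title, the added branch presumably just returns the checkpoint unchanged, along the lines of:

if type(other) is cls:
    return other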
A member commented on the diff:

This behavior makes sense to me, but I need another pair of eyes with full checkpoint context. @krfricke and @xwjiang2010?

@krfricke (Contributor) commented Oct 3, 2022:

Main question here is whether we should be less strict by using isinstance(other, cls), which would also return True for subclasses.

@matthewdeng (Contributor) commented:

Can you explain why it's created with default attributes? Is that not a bug?

@bveeramani (Member, Author) commented Sep 30, 2022:

> Can you explain why it's created with default attributes? Is that not a bug?

It's sort of a bug; more precisely, the behavior of Checkpoint.from_checkpoint is poorly defined.

TensorflowPredictor.from_checkpoint casts the input checkpoint to a TensorflowCheckpoint by calling TensorflowCheckpoint.from_checkpoint. This is necessary to avoid the errors described in #28134.

checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint)
model_weights = checkpoint.get_model_weights()
preprocessor = checkpoint.get_preprocessor()

Attributes are restored at the from_dict / from_directory layer of abstraction. Because TensorflowCheckpoint.from_checkpoint creates a new checkpoint with the constructor as opposed to from_dict / from_directory, the attributes aren't restored.

from typing import Type, TypeVar

T = TypeVar("T", bound=Checkpoint)

def from_checkpoint(cls: Type[T], other: Checkpoint) -> T:
    # Re-constructs via the constructor, so subclass attributes fall back to defaults.
    return cls(
        local_path=other._local_path,
        data_dict=other._data_dict,
        uri=other._uri,
        obj_ref=other._obj_ref,
    )

An alternative implementation would be to call from_* methods in from_checkpoint.

def from_checkpoint(cls: Type[T], other: Checkpoint) -> T:
    if other._data_dict:
        return cls.from_dict(other._data_dict)
    ...

The problem with this approach is that we'd also restore the checkpoint type.

>>> checkpoint: TorchCheckpoint = ...
>>> new_checkpoint = Checkpoint.from_checkpoint(checkpoint)
>>> type(new_checkpoint)  # Arguably should be `Checkpoint`
<class 'TorchCheckpoint'>

In any case, from_checkpoint is a short-term hack. We should handle it soon, either by removing it or by making it an implementation detail.

@matthewdeng (Contributor) commented:

Ah, thanks for the explanation.

> In any case, from_checkpoint is a short-term hack. We should handle it soon, either by removing it or by making it an implementation detail.

Can this be documented as a GitHub issue and/or in the code as a TODO? It would be great to have this context captured so we (a) know how to design a better long-term fix, and (b) can reasonably evaluate short-term fixes (e.g., this PR).

@krfricke (Contributor) left a comment:

LGTM, quick question before merge


@amogkam (Contributor) left a comment:

Actually thinking about it some more @bveeramani, do we even need to use from_checkpoint in predictors?

The predictors could use the passed-in checkpoints directly without needing to call from_checkpoint.

@bveeramani (Member, Author) commented:

> Main question here is whether we should be less strict by using isinstance(other, cls), which would also return True for subclasses.

@krfricke the decision was somewhat arbitrary, but the motivation is to make cls.from_checkpoint always return a cls.

Currently, if you do

>>> torch_checkpoint = TorchCheckpoint.from_state_dict(...)
>>> checkpoint = Checkpoint.from_checkpoint(torch_checkpoint)

you'd get back a Checkpoint

>>> type(checkpoint)
<class 'Checkpoint'>

but if we did isinstance(other, cls), you'd call Checkpoint.from_checkpoint but get a TorchCheckpoint back.

>>> type(checkpoint)
<class 'TorchCheckpoint'>
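
To make the distinction concrete, a standalone illustration with plain stand-in classes (independent of Ray):

class Checkpoint:
    pass

class TorchCheckpoint(Checkpoint):
    pass

ckpt = TorchCheckpoint()
print(type(ckpt) is Checkpoint)       # False: exact-type check rejects subclasses
print(type(ckpt) is TorchCheckpoint)  # True
print(isinstance(ckpt, Checkpoint))   # True: isinstance accepts subclasses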

@amogkam (Contributor) left a comment:

> Actually thinking about it some more @bveeramani, do we even need to use from_checkpoint in predictors?

Seems like this is still necessary until #28910 is fixed.

This PR looks good as a short-term fix.

@bveeramani (Member, Author) commented:

> Actually thinking about it some more @bveeramani, do we even need to use from_checkpoint in predictors?
>
> The predictors could use the passed-in checkpoints directly without needing to call from_checkpoint.

If we want to support TorchPredictor.from_checkpoint(checkpoint: Checkpoint) (as opposed to TorchPredictor.from_checkpoint(checkpoint: TorchCheckpoint)), then we need to call Checkpoint.from_checkpoint in predictors.
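
A minimal stand-in sketch of that call pattern (simplified classes, not Ray's actual code; names are illustrative):

class Checkpoint:
    def __init__(self, data=None):
        self.data = data or {}

    @classmethod
    def from_checkpoint(cls, other):
        # With this PR's change: skip the cast when the type already matches.
        return other if type(other) is cls else cls(other.data)

class TorchCheckpoint(Checkpoint):
    def get_model(self):
        return self.data.get("model")

class TorchPredictor:
    def __init__(self, model):
        self.model = model

    @classmethod
    def from_checkpoint(cls, checkpoint):
        # Cast so Torch-specific accessors exist even for a plain Checkpoint.
        checkpoint = TorchCheckpoint.from_checkpoint(checkpoint)
        return cls(checkpoint.get_model())

# Both a plain Checkpoint and a TorchCheckpoint are accepted:
print(TorchPredictor.from_checkpoint(Checkpoint({"model": "net"})).model)       # net
print(TorchPredictor.from_checkpoint(TorchCheckpoint({"model": "net"})).model)  # net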

But yeah, in any case, we should try to unblock #28474.

@krfricke merged commit e705f03 into ray-project:master on Oct 3, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022