Skip to content

Commit

Permalink
Merge pull request #31 from nasaharvest/model-device
Browse files Browse the repository at this point in the history
Move presto pretrained model to the right device
  • Loading branch information
gabrieltseng authored Dec 23, 2023
2 parents f6422ad + 6d8c6c9 commit 632cd39
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion presto/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,4 +789,4 @@ def construct_finetuning_model(
def load_pretrained(cls, model_path: Union[str, Path] = default_model_path):
model = cls.construct()
model.load_state_dict(torch.load(model_path, map_location=device))
return model
return model.to(device)

0 comments on commit 632cd39

Please sign in to comment.