Skip to content

Commit

Permalink
[train] Fix HuggingFace -> Transformers wrapping logic (ray-project#3…
Browse files Browse the repository at this point in the history
…5276)

Properly pass constructor arguments through.

Signed-off-by: Matthew Deng <[email protected]>
Signed-off-by: Marcus Zhang <[email protected]>
  • Loading branch information
matthewdeng authored and Marcus Zhang committed May 12, 2023
1 parent b907ed9 commit a72c1ee
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/ray/train/huggingface/huggingface_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HuggingFaceCheckpoint(TransformersCheckpoint):
# than __init__
def __new__(cls: type, *args, **kwargs):
warnings.warn(deprecation_msg, DeprecationWarning)
return super(HuggingFaceCheckpoint, cls).__new__(cls)
return super(HuggingFaceCheckpoint, cls).__new__(cls, *args, **kwargs)


__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion python/ray/train/huggingface/huggingface_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HuggingFacePredictor(TransformersPredictor):
# than __init__
def __new__(cls: type, *args, **kwargs):
warnings.warn(deprecation_msg, DeprecationWarning)
return super(HuggingFacePredictor, cls).__new__(cls)
return super(HuggingFacePredictor, cls).__new__(cls, *args, **kwargs)


__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion python/ray/train/huggingface/huggingface_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HuggingFaceTrainer(TransformersTrainer):
# than __init__
def __new__(cls: type, *args, **kwargs):
warnings.warn(deprecation_msg, DeprecationWarning)
return super(HuggingFaceTrainer, cls).__new__(cls)
return super(HuggingFaceTrainer, cls).__new__(cls, *args, **kwargs)


__all__ = [
Expand Down

0 comments on commit a72c1ee

Please sign in to comment.