Skip to content

Commit

Permalink
Fix the deprecation warning of _torch_pytree._register_pytree_node (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyever authored Dec 17, 2023
1 parent f85a1e8 commit e6dcf8a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def __init_subclass__(cls) -> None:
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
_torch_pytree._register_pytree_node(
torch_pytree_register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
Expand Down Expand Up @@ -438,7 +438,11 @@ def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Conte
output_type, keys = context
return output_type(**dict(zip(keys, values)))

_torch_pytree._register_pytree_node(
if hasattr(_torch_pytree, "register_pytree_node"):
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
else:
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
torch_pytree_register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
Expand Down

0 comments on commit e6dcf8a

Please sign in to comment.