diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 28e63ce45b8eae..d5762337b50ac8 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -242,7 +242,7 @@ def is_jax_tensor(x): def _is_mlx(x): - import mx.core as mx + import mlx.core as mx return isinstance(x, mx.array)