Skip to content

Commit

Permalink
Add mlx support to BatchEncoding.convert_to_tensors (#29406)
Browse files Browse the repository at this point in the history
* Add mlx support

* Fix import order and use def instead of lambda

* Another fix for ruff format :)

* Add detecting mlx from repr, add is_mlx_array
  • Loading branch information
tidely authored Mar 4, 2024
1 parent 39ef3fb commit 704b3f7
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
11 changes: 11 additions & 0 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
extract_commit_hash,
is_flax_available,
is_jax_tensor,
is_mlx_available,
is_numpy_array,
is_offline_mode,
is_remote_url,
Expand Down Expand Up @@ -726,6 +727,16 @@ def as_tensor(value, dtype=None):

as_tensor = jnp.array
is_tensor = is_jax_tensor

elif tensor_type == TensorType.MLX:
if not is_mlx_available():
raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
import mlx.core as mx

as_tensor = mx.array

def is_tensor(obj):
return isinstance(obj, mx.array)
else:

def as_tensor(value, dtype=None):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_mlx_available,
is_natten_available,
is_ninja_available,
is_nltk_available,
Expand Down
30 changes: 27 additions & 3 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@
import numpy as np
from packaging import version

from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
from .import_utils import (
get_torch_version,
is_flax_available,
is_mlx_available,
is_tf_available,
is_torch_available,
is_torch_fx_proxy,
)


if is_flax_available():
Expand Down Expand Up @@ -87,6 +94,8 @@ def infer_framework_from_repr(x):
return "jax"
elif representation.startswith("<class 'numpy."):
return "np"
elif representation.startswith("<class 'mlx."):
return "mlx"


def _get_frameworks_and_test_func(x):
Expand All @@ -99,6 +108,7 @@ def _get_frameworks_and_test_func(x):
"tf": is_tf_tensor,
"jax": is_jax_tensor,
"np": is_numpy_array,
"mlx": is_mlx_array,
}
preferred_framework = infer_framework_from_repr(x)
# We will test this one first, then numpy, then the others.
Expand All @@ -111,8 +121,8 @@ def _get_frameworks_and_test_func(x):

def is_tensor(x):
"""
Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray` in the order
defined by `infer_framework_from_repr`
Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array`
in the order defined by `infer_framework_from_repr`
"""
# This gives us a smart order to test the frameworks with the corresponding tests.
framework_to_test_func = _get_frameworks_and_test_func(x)
Expand Down Expand Up @@ -231,6 +241,19 @@ def is_jax_tensor(x):
return False if not is_flax_available() else _is_jax(x)


def _is_mlx(x):
    """
    Tests if `x` is an `mlx.core.array`.

    This helper imports mlx unconditionally, so it must only be called when mlx is installed; use `is_mlx_array` for
    a check that is safe without mlx.
    """
    # Imported lazily so that merely loading this module does not require mlx.
    # The package is named `mlx` (not `mx`); `mx` is only the conventional alias.
    import mlx.core as mx

    return isinstance(x, mx.array)


def is_mlx_array(x):
    """
    Checks whether `x` is an mlx array. Can be called safely when mlx is not installed: in that case the answer is
    simply `False` and the mlx-specific check is never reached.
    """
    if not is_mlx_available():
        return False
    return _is_mlx(x)


def to_py_obj(obj):
"""
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
Expand Down Expand Up @@ -499,6 +522,7 @@ class TensorType(ExplicitEnum):
TENSORFLOW = "tf"
NUMPY = "np"
JAX = "jax"
MLX = "mlx"


class ContextManagers:
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")


_torch_version = "N/A"
Expand Down Expand Up @@ -923,6 +924,10 @@ def is_jinja_available():
return _jinja_available


def is_mlx_available():
    """
    Returns the module-level `_mlx_available` flag, which is set once from `_is_package_available("mlx")`.
    """
    return _mlx_available


# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
Expand Down

0 comments on commit 704b3f7

Please sign in to comment.