Skip to content

Commit

Permalink
Merge branch 'main' of github.com:keras-team/keras-core
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jun 26, 2023
2 parents 6818e66 + 71b02f4 commit 8a6da7e
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 33 deletions.
7 changes: 7 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
try:
# When using torch and tensorflow, torch needs to be imported first,
# otherwise it will segfault upon import. This should force the torch
# import to happen first for all tests.
import torch # noqa: F401
except ImportError:
pass
2 changes: 1 addition & 1 deletion keras_core/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _initialize(self, value):
)

def _direct_assign(self, value):
self.value.assign(value)
self._value.assign(tf.cast(value, self._value.dtype))

def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions keras_core/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def convert_to_tensor(x, dtype=None):
# Convert to np in case of any array-like that is not list or tuple.
if not isinstance(x, (list, tuple)):
x = np.array(x)
elif len(x) > 0 and isinstance(x[0], torch.Tensor):
elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
# Handle list or tuple of torch tensors
return torch.stack(x)
return torch.stack([convert_to_tensor(x1) for x1 in x])
if isinstance(x, np.ndarray) and x.dtype == np.uint32:
# Torch backend does not support uint32.
x = x.astype(np.int64)
Expand Down
4 changes: 4 additions & 0 deletions keras_core/operations/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,9 @@ def test_convert_to_tensor(self):
self.assertAllEqual(x, (1, 1))
self.assertIsInstance(x, np.ndarray)

# Partially converted.
x = ops.convert_to_tensor((1, ops.array(2), 3))
self.assertAllEqual(x, (1, 2, 3))

with self.assertRaises(ValueError):
ops.convert_to_numpy(KerasTensor((2,)))
7 changes: 1 addition & 6 deletions keras_core/saving/serialization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import types
import warnings

import jax
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -162,11 +161,7 @@ def serialize_keras_object(obj):
}
if isinstance(obj, tf.TensorShape):
return obj.as_list() if obj._dims is not None else None
if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)) or hasattr(
obj, "device"
):
# Import torch creates circular dependency, so we use
# `hasattr(obj, "device")` to check if obj is a torch tensor.
if backend.is_tensor(obj):
return {
"class_name": "__tensor__",
"config": {
Expand Down
39 changes: 29 additions & 10 deletions keras_core/trainers/data_adapters/array_data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ def __init__(
shuffle=False,
class_weight=None,
):
types_struct = nest.map_structure(lambda x: type(x), x)
flat_types = nest.flatten(types_struct)
if not all(
issubclass(c, data_adapter_utils.ARRAY_TYPES) for c in flat_types
):
if not can_convert_arrays((x, y, sample_weight)):
raise ValueError(
"Expected all elements of `x` to be array-like. "
f"Received invalid types: x={x}"
Expand Down Expand Up @@ -252,6 +248,28 @@ def partial_batch_size(self):
return self._partial_batch_size or None


def can_convert_arrays(arrays):
"""Check if array like-inputs can be handled by `ArrayDataAdapter`
Args:
inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
Returns:
`True` if `arrays` can be handled by `ArrayDataAdapter`, `False`
otherwise.
"""

def can_convert_single_array(x):
is_none = x is None
known_type = isinstance(x, data_adapter_utils.ARRAY_TYPES)
convertable_type = hasattr(x, "__array__")
return is_none or known_type or convertable_type

return all(
tf.nest.flatten(tf.nest.map_structure(can_convert_single_array, arrays))
)


def convert_to_arrays(arrays, dtype=None):
"""Process array-like inputs.
Expand All @@ -262,7 +280,7 @@ def convert_to_arrays(arrays, dtype=None):
- Converts `list`s to `tuple`s (for `tf.data` support).
Args:
inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
Returns:
Structure of NumPy `ndarray`s.
Expand All @@ -277,15 +295,16 @@ def convert_single_array(x):
x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
elif isinstance(x, pandas.DataFrame):
x = x.to_numpy(dtype=dtype)
if isinstance(x, (tf.Tensor, tf.Variable)):
x = x.numpy()
if not isinstance(x, np.ndarray):
# Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
# `torch.Tensor`, as well as any other tensor-like object that has
# added numpy support.
if hasattr(x, "__array__"):
x = np.array(x, dtype=dtype)
else:
raise ValueError(
"Expected a NumPy array, tf.Tensor, "
"Pandas Dataframe, or Pandas Series. "
"Expected a NumPy array, tf.Tensor, jax.np.ndarray, "
"torch.Tensor, Pandas Dataframe, or Pandas Series. "
f"Received invalid input: {x} (of type {type(x)})"
)
if x.dtype == object:
Expand Down
8 changes: 7 additions & 1 deletion keras_core/trainers/data_adapters/array_data_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas
import tensorflow as tf
import torch
from absl.testing import parameterized

from keras_core import backend
Expand All @@ -10,7 +11,9 @@


class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
@parameterized.parameters([("np",), ("tf",), ("pandas")])
@parameterized.parameters(
[("np",), ("tf",), ("jax",), ("torch",), ("pandas")]
)
def test_basic_flow(self, array_type):
if array_type == "np":
x = np.random.random((34, 4))
Expand All @@ -21,6 +24,9 @@ def test_basic_flow(self, array_type):
elif array_type == "jax":
x = jax.numpy.ones((34, 4))
y = jax.numpy.ones((34, 2))
elif array_type == "torch":
x = torch.ones((34, 4))
y = torch.ones((34, 2))
elif array_type == "pandas":
x = pandas.DataFrame(np.random.random((34, 4)))
y = pandas.DataFrame(np.random.random((34, 2)))
Expand Down
14 changes: 5 additions & 9 deletions keras_core/trainers/data_adapters/data_adapter_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math

import jax
import numpy as np
import tensorflow as tf

Expand All @@ -12,15 +11,12 @@
pandas = None


ARRAY_TYPES = (tf.Tensor, np.ndarray, jax.numpy.ndarray)
# Leave jax, tf, and torch arrays off this list. Instead we will use
# `__array__` to detect these types. Doing so allows us to avoid importing a
# backend framework we are not currently using just to do type-checking.
ARRAY_TYPES = (np.ndarray,)
if pandas:
ARRAY_TYPES = ARRAY_TYPES + (
tf.Tensor,
np.ndarray,
pandas.Series,
pandas.DataFrame,
)
# TODO: support torch tensors?
ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)


@keras_core_export("keras_core.utils.unpack_x_y_sample_weight")
Expand Down
5 changes: 1 addition & 4 deletions keras_core/trainers/epoch_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@
import warnings

import tensorflow as tf
from tensorflow import nest

from keras_core.trainers.data_adapters import array_data_adapter
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.data_adapters import generator_data_adapter
from keras_core.trainers.data_adapters import py_dataset_adapter
from keras_core.trainers.data_adapters import tf_dataset_adapter
Expand All @@ -69,8 +67,7 @@ def __init__(
if steps_per_epoch:
self._current_iterator = None
self._insufficient_data = False
first_element = next(iter(nest.flatten(x)), None)
if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
if array_data_adapter.can_convert_arrays((x, y, sample_weight)):
self.data_adapter = array_data_adapter.ArrayDataAdapter(
x,
y,
Expand Down

0 comments on commit 8a6da7e

Please sign in to comment.