From 6f61f7866932100239c3320703bd6de19a5dd5a7 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 20 Jun 2023 11:21:41 -0700
Subject: [PATCH 1/6] Top k fix (#376)

* jax ignore sorted in top_k

* Ignore sorted argument for jax top_k

`sorted=True` is a strictly stronger guarantee than `sorted=False`, so it is
better to always return `sorted=True` than to add an annoying inconsistency
between backends in which arguments they support.
---
 keras_core/backend/jax/math.py     | 6 ++----
 keras_core/operations/math_test.py | 8 +++++++-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/keras_core/backend/jax/math.py b/keras_core/backend/jax/math.py
index f1bd0b42a..53d8e0ca7 100644
--- a/keras_core/backend/jax/math.py
+++ b/keras_core/backend/jax/math.py
@@ -14,10 +14,8 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
 
 def top_k(x, k, sorted=True):
-    if not sorted:
-        return ValueError(
-            "Jax backend does not support `sorted=False` for `ops.top_k`"
-        )
+    # Jax does not support `sorted`, but in the case where `sorted=False`,
+    # order is not guaranteed, so it is OK to return sorted output.
     return jax.lax.top_k(x, k)
 
diff --git a/keras_core/operations/math_test.py b/keras_core/operations/math_test.py
index 4f55f132b..9f67da56b 100644
--- a/keras_core/operations/math_test.py
+++ b/keras_core/operations/math_test.py
@@ -20,7 +20,7 @@ def test_segment_sum(self):
         outputs = kmath.segment_sum(data, segment_ids, num_segments=5)
         self.assertEqual(outputs.shape, (5, 4))
 
-    def test_topk(self):
+    def test_top_k(self):
         x = KerasTensor((None, 2, 3))
         values, indices = kmath.top_k(x, k=1)
         self.assertEqual(values.shape, (None, 2, 1))
@@ -155,6 +155,12 @@ def test_top_k(self):
         self.assertAllClose(values, [4, 3])
         self.assertAllClose(indices, [1, 4])
 
+        x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32)
+        values, indices = kmath.top_k(x, k=2, sorted=False)
+        # Any order is OK when `sorted=False`.
+        self.assertEqual(set(backend.convert_to_numpy(values)), set([4, 3]))
+        self.assertEqual(set(backend.convert_to_numpy(indices)), set([1, 4]))
+
         x = np.random.rand(5, 5)
         outputs = kmath.top_k(x, k=2)
         expected = tf.math.top_k(x, k=2)

From a85eb73252dda52b0aa4d9a122e2836cc961c512 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 20 Jun 2023 14:46:57 -0700
Subject: [PATCH 2/6] Add keras_core.operations.convert_to_numpy (#378)

For downstream use cases, it is nice to have this in the ops layer instead
of being forced to reach into the backend.
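A brief usage sketch (illustrative only, not part of the patch; it assumes
the public `keras_core.operations` namespace exercised by the tests below):

    import numpy as np
    from keras_core import operations as ops

    x = ops.convert_to_tensor(np.ones((2,)))  # NumPy array -> backend tensor
    y = ops.convert_to_numpy(x)  # backend tensor -> NumPy array
    assert isinstance(y, np.ndarray)
    # Symbolic tensors have no concrete value, so converting one raises:
    # ops.convert_to_numpy(KerasTensor((2,)))  -> ValueError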
---
 keras_core/operations/__init__.py  |  1 -
 keras_core/operations/core.py      | 18 ++++++++++++++++++
 keras_core/operations/core_test.py | 10 ++++++++++
 3 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/keras_core/operations/__init__.py b/keras_core/operations/__init__.py
index 5777ea576..2b9988d23 100644
--- a/keras_core/operations/__init__.py
+++ b/keras_core/operations/__init__.py
@@ -4,7 +4,6 @@
 from keras_core.backend import cast
 from keras_core.backend import cond
-from keras_core.backend import convert_to_tensor
 from keras_core.backend import is_tensor
 from keras_core.backend import name_scope
 from keras_core.backend import random
diff --git a/keras_core/operations/core.py b/keras_core/operations/core.py
index 1f613a2ad..e8d8efdd1 100644
--- a/keras_core/operations/core.py
+++ b/keras_core/operations/core.py
@@ -279,3 +279,21 @@ def cast(x, dtype):
     if any_symbolic_tensors((x,)):
         return backend.KerasTensor(shape=x.shape, dtype=dtype)
     return backend.core.cast(x, dtype)
+
+
+@keras_core_export("keras_core.operations.convert_to_tensor")
+def convert_to_tensor(x, dtype=None):
+    """Convert a NumPy array to a tensor."""
+    return backend.convert_to_tensor(x, dtype=dtype)
+
+
+@keras_core_export("keras_core.operations.convert_to_numpy")
+def convert_to_numpy(x):
+    """Convert a tensor to a NumPy array."""
+    if any_symbolic_tensors((x,)):
+        raise ValueError(
+            "A symbolic tensor (usually the result of applying layers or "
+            "operations to a `keras.Input`) cannot be converted to a numpy "
+            "array. There is no concrete value for the input."
+        )
+    return backend.convert_to_numpy(x)
diff --git a/keras_core/operations/core_test.py b/keras_core/operations/core_test.py
index 3fed1bd2e..d4a2f98b6 100644
--- a/keras_core/operations/core_test.py
+++ b/keras_core/operations/core_test.py
@@ -231,3 +231,13 @@ def test_shape(self):
 
         x = KerasTensor((None, 3, None, 1))
         self.assertAllEqual(core.shape(x), (None, 3, None, 1))
+
+    def test_convert_to_tensor(self):
+        x = np.ones((2,))
+        x = ops.convert_to_tensor(x)
+        x = ops.convert_to_numpy(x)
+        self.assertAllEqual(x, (1, 1))
+        self.assertIsInstance(x, np.ndarray)
+
+        with self.assertRaises(ValueError):
+            ops.convert_to_numpy(KerasTensor((2,)))

From 37dfbde962de5069cf6036668df6aa7fbc9d88ea Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Tue, 20 Jun 2023 15:12:55 -0700
Subject: [PATCH 3/6] Add explanation on torch device management (#379)

---
 keras_core/backend/torch/__init__.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py
index 8bbf85be5..c7449978e 100644
--- a/keras_core/backend/torch/__init__.py
+++ b/keras_core/backend/torch/__init__.py
@@ -1,3 +1,16 @@
+"""Torch backend APIs.
+
+Torch has a different logic of device management compared to TF and JAX. In
+short variables/tensors are not by default created on GPU, and GPU cannot
+directly communicate with CPU. Therefore, we are doing the following to automate
+device management for Torch backend, if GPU is available:
+
+- Variables are created on GPU.
+- Input data will be placed on GPU at the first `keras_core.layers.Layer` call.
+- Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU.
+- `convert_to_numpy` will bring the tensor to CPU and convert to numpy array.
+""" + from keras_core.backend.torch import core from keras_core.backend.torch import image from keras_core.backend.torch import math From 62b2ae6f1f5594b4375b8b6bbc325f58cf472520 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 20 Jun 2023 15:16:03 -0700 Subject: [PATCH 4/6] Nits. --- keras_core/backend/torch/__init__.py | 13 ++++++++----- keras_core/operations/core.py | 5 +++++ keras_core/trainers/trainer.py | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py index c7449978e..8bc9bddca 100644 --- a/keras_core/backend/torch/__init__.py +++ b/keras_core/backend/torch/__init__.py @@ -1,14 +1,17 @@ """Torch backend APIs. -Torch has a different logic of device management compared to TF and JAX. In -short variables/tensors are not by default created on GPU, and GPU cannot -directly communicate with CPU. Therefore, we are doing the following to automate -device management for Torch backend, if GPU is available: +# Note on device placement + +Torch has a different device placement style compared to TF and JAX. +In short, variables/tensors are not created on GPU by default, +and the GPU cannot directly communicate with the CPU. +To bring Torch behavior in line with TF and JAX automated device placement, +we are doing the following to automate device placement if a GPU is available: - Variables are created on GPU. - Input data will be placed on GPU at the first `keras_core.layers.Layer` call. - Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU. -- `convert_to_numpy` will bring the tensor to CPU and convert to numpy array. +- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy. """ from keras_core.backend.torch import core diff --git a/keras_core/operations/core.py b/keras_core/operations/core.py index e8d8efdd1..bc66306d4 100644 --- a/keras_core/operations/core.py +++ b/keras_core/operations/core.py @@ -4,6 +4,11 @@ slice slice_update while_loop +stop_gradient +shape +cast +convert_to_tensor +convert_to_numpy """ from keras_core import backend diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index 75447963e..5f8c29a0e 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -43,7 +43,7 @@ def compile( else: self._compile_metrics = None if jit_compile == "auto": - if model_supports_jit(self): + if not run_eagerly and model_supports_jit(self): jit_compile = True else: jit_compile = False From 4c435b24020870ade83794052c77191941eeb755 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 21 Jun 2023 00:10:55 +0000 Subject: [PATCH 5/6] Fix `keras_core.operations.Repeat` op's `compute_output_spec` method (#380) --- keras_core/operations/numpy.py | 5 ++++- keras_core/operations/numpy_test.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/keras_core/operations/numpy.py b/keras_core/operations/numpy.py index cf6881b0e..2d412f4b7 100644 --- a/keras_core/operations/numpy.py +++ b/keras_core/operations/numpy.py @@ -2665,7 +2665,10 @@ def compute_output_spec(self, x): size_on_ax = x_shape[self.axis] output_shape = x_shape if isinstance(self.repeats, int): - output_shape[self.axis] = size_on_ax * self.repeats + if size_on_ax is None: + output_shape[self.axis] = None + else: + output_shape[self.axis] = size_on_ax * self.repeats else: output_shape[self.axis] = int(np.sum(self.repeats)) return KerasTensor(output_shape, dtype=x.dtype) diff --git a/keras_core/operations/numpy_test.py 
b/keras_core/operations/numpy_test.py
index 580e7674d..909769cab 100644
--- a/keras_core/operations/numpy_test.py
+++ b/keras_core/operations/numpy_test.py
@@ -1027,6 +1027,7 @@ def test_repeat(self):
         self.assertEqual(knp.repeat(x, 2).shape, (None,))
         self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9))
         self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3))
+        self.assertEqual(knp.repeat(x, 2, axis=0).shape, (None, 3))
 
     def test_reshape(self):
         x = KerasTensor([None, 3])

From 9c575e676d28143b027a541248d68ffb6fd4dbf5 Mon Sep 17 00:00:00 2001
From: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
Date: Tue, 20 Jun 2023 18:39:12 -0600
Subject: [PATCH 6/6] Update Torch ops.array (#375)

* Update Torch ops.array

This was previously inconsistent with Jax and TF, where `ops.array` produces
an array/tensor of the native backend type; the Torch version instead
produced a NumPy array.

* Add ops test

* Fix metrics tests

* Update torch np.tile implementation
---
 keras_core/backend/torch/numpy.py   | 7 +++++--
 keras_core/metrics/metrics_utils.py | 2 +-
 keras_core/operations/numpy_test.py | 2 ++
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py
index b2afdf8ae..76849e140 100644
--- a/keras_core/backend/torch/numpy.py
+++ b/keras_core/backend/torch/numpy.py
@@ -4,6 +4,7 @@
 from keras_core.backend.torch.core import cast
 from keras_core.backend.torch.core import convert_to_tensor
 from keras_core.backend.torch.core import get_device
+from keras_core.backend.torch.core import is_tensor
 from keras_core.backend.torch.core import to_torch_dtype
 
 TORCH_INT_TYPES = (
@@ -187,9 +188,9 @@ def argsort(x, axis=-1):
 
 def array(x, dtype=None):
     dtype = to_torch_dtype(dtype)
-    if not isinstance(x, torch.Tensor):
+    if isinstance(x, torch.Tensor):
         return x
-    return x.numpy()
+    return torch.tensor(x, dtype=dtype)
 
 
 def average(x, axis=None, weights=None):
@@ -754,6 +755,8 @@ def round(x, decimals=0):
 
 def tile(x, repeats):
+    if is_tensor(repeats):
+        repeats = tuple(repeats.int().numpy())
     x = convert_to_tensor(x)
     return torch.tile(x, dims=repeats)
 
diff --git a/keras_core/metrics/metrics_utils.py b/keras_core/metrics/metrics_utils.py
index 4c726f9a2..b3372c319 100644
--- a/keras_core/metrics/metrics_utils.py
+++ b/keras_core/metrics/metrics_utils.py
@@ -509,7 +509,7 @@ def update_confusion_matrix_variables(
     data_tiles = [num_thresholds, 1]
 
     thresh_tiled = ops.tile(
-        ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
+        ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
    )
 
     # Tile the predictions for every threshold.
diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py
index 909769cab..3d25bf5d5 100644
--- a/keras_core/operations/numpy_test.py
+++ b/keras_core/operations/numpy_test.py
@@ -2242,6 +2242,8 @@ def test_array(self):
         x = np.array([[1, 2, 3], [3, 2, 1]])
         self.assertAllClose(knp.array(x), np.array(x))
         self.assertAllClose(knp.Array()(x), np.array(x))
+        self.assertTrue(backend.is_tensor(knp.array(x)))
+        self.assertTrue(backend.is_tensor(knp.Array()(x)))
 
     def test_average(self):
         x = np.array([[1, 2, 3], [3, 2, 1]])
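A closing illustration of the behavior this last patch pins down (a hedged
sketch only, assuming the torch backend is active and that `array`,
`is_tensor`, and `convert_to_numpy` are exposed through the public
`keras_core.operations` and `keras_core.backend` namespaces used in the
tests above):

    import numpy as np
    from keras_core import backend
    from keras_core import operations as ops

    # `ops.array` now returns a backend-native tensor on torch, matching the
    # Jax and TF backends, rather than silently returning a NumPy array.
    x = ops.array(np.array([[1, 2, 3], [3, 2, 1]]))
    assert backend.is_tensor(x)
    # Round-trip back to NumPy (on torch this also moves the tensor to CPU,
    # per the device-placement notes in patch 4).
    assert (backend.convert_to_numpy(x) == [[1, 2, 3], [3, 2, 1]]).all()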