Skip to content

Commit

Permalink
Merge branch 'main' into aritra-np-backend
Browse files Browse the repository at this point in the history
OK
  • Loading branch information
ariG23498 committed Jun 21, 2023
2 parents 4586e33 + 9c575e6 commit 8373103
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 11 deletions.
6 changes: 2 additions & 4 deletions keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):


def top_k(x, k, sorted=True):
if not sorted:
return ValueError(
"Jax backend does not support `sorted=False` for `ops.top_k`"
)
# Jax does not supported `sorted`, but in the case where `sorted=False`,
# order is not guaranteed, so OK to return sorted output.
return jax.lax.top_k(x, k)


Expand Down
16 changes: 16 additions & 0 deletions keras_core/backend/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""Torch backend APIs.
# Note on device placement
Torch has a different device placement style compared to TF and JAX.
In short, variables/tensors are not created on GPU by default,
and the GPU cannot directly communicate with the CPU.
To bring Torch behavior in line with TF and JAX automated device placement,
we are doing the following to automate device placement if a GPU is available:
- Variables are created on GPU.
- Input data will be placed on GPU at the first `keras_core.layers.Layer` call.
- Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU.
- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy.
"""

from keras_core.backend.torch import core
from keras_core.backend.torch import image
from keras_core.backend.torch import math
Expand Down
7 changes: 5 additions & 2 deletions keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import convert_to_tensor
from keras_core.backend.torch.core import get_device
from keras_core.backend.torch.core import is_tensor
from keras_core.backend.torch.core import to_torch_dtype

TORCH_INT_TYPES = (
Expand Down Expand Up @@ -187,9 +188,9 @@ def argsort(x, axis=-1):

def array(x, dtype=None):
dtype = to_torch_dtype(dtype)
if not isinstance(x, torch.Tensor):
if isinstance(x, torch.Tensor):
return x
return x.numpy()
return torch.tensor(x, dtype=dtype)


def average(x, axis=None, weights=None):
Expand Down Expand Up @@ -754,6 +755,8 @@ def round(x, decimals=0):


def tile(x, repeats):
if is_tensor(repeats):
repeats = tuple(repeats.int().numpy())
x = convert_to_tensor(x)
return torch.tile(x, dims=repeats)

Expand Down
2 changes: 1 addition & 1 deletion keras_core/metrics/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def update_confusion_matrix_variables(
data_tiles = [num_thresholds, 1]

thresh_tiled = ops.tile(
ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
)

# Tile the predictions for every threshold.
Expand Down
1 change: 0 additions & 1 deletion keras_core/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from keras_core.backend import cast
from keras_core.backend import cond
from keras_core.backend import convert_to_tensor
from keras_core.backend import is_tensor
from keras_core.backend import name_scope
from keras_core.backend import random
Expand Down
23 changes: 23 additions & 0 deletions keras_core/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
slice
slice_update
while_loop
stop_gradient
shape
cast
convert_to_tensor
convert_to_numpy
"""

from keras_core import backend
Expand Down Expand Up @@ -279,3 +284,21 @@ def cast(x, dtype):
if any_symbolic_tensors((x,)):
return backend.KerasTensor(shape=x.shape, dtype=dtype)
return backend.core.cast(x, dtype)


@keras_core_export("keras_core.operations.convert_to_tensor")
def convert_to_tensor(x, dtype=None):
"""Convert a NumPy array to a tensor."""
return backend.convert_to_tensor(x, dtype=dtype)


@keras_core_export("keras_core.operations.convert_to_numpy")
def convert_to_numpy(x):
"""Convert a tensor to a NumPy array."""
if any_symbolic_tensors((x,)):
raise ValueError(
"A symbolic tensor (usually the result of applying layers or "
"operations to a `keras.Input`), cannot be converted to a numpy "
"array. There is no concrete value for the input."
)
return backend.convert_to_numpy(x)
10 changes: 10 additions & 0 deletions keras_core/operations/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,13 @@ def test_shape(self):

x = KerasTensor((None, 3, None, 1))
self.assertAllEqual(core.shape(x), (None, 3, None, 1))

def test_convert_to_tensor(self):
x = np.ones((2,))
x = ops.convert_to_tensor(x)
x = ops.convert_to_numpy(x)
self.assertAllEqual(x, (1, 1))
self.assertIsInstance(x, np.ndarray)

with self.assertRaises(ValueError):
ops.convert_to_numpy(KerasTensor((2,)))
8 changes: 7 additions & 1 deletion keras_core/operations/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_segment_sum(self):
outputs = kmath.segment_sum(data, segment_ids, num_segments=5)
self.assertEqual(outputs.shape, (5, 4))

def test_topk(self):
def test_top_k(self):
x = KerasTensor((None, 2, 3))
values, indices = kmath.top_k(x, k=1)
self.assertEqual(values.shape, (None, 2, 1))
Expand Down Expand Up @@ -155,6 +155,12 @@ def test_top_k(self):
self.assertAllClose(values, [4, 3])
self.assertAllClose(indices, [1, 4])

x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32)
values, indices = kmath.top_k(x, k=2, sorted=False)
# Any order ok when `sorted=False`.
self.assertEqual(set(backend.convert_to_numpy(values)), set([4, 3]))
self.assertEqual(set(backend.convert_to_numpy(indices)), set([1, 4]))

x = np.random.rand(5, 5)
outputs = kmath.top_k(x, k=2)
expected = tf.math.top_k(x, k=2)
Expand Down
5 changes: 4 additions & 1 deletion keras_core/operations/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,7 +2665,10 @@ def compute_output_spec(self, x):
size_on_ax = x_shape[self.axis]
output_shape = x_shape
if isinstance(self.repeats, int):
output_shape[self.axis] = size_on_ax * self.repeats
if size_on_ax is None:
output_shape[self.axis] = None
else:
output_shape[self.axis] = size_on_ax * self.repeats
else:
output_shape[self.axis] = int(np.sum(self.repeats))
return KerasTensor(output_shape, dtype=x.dtype)
Expand Down
3 changes: 3 additions & 0 deletions keras_core/operations/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ def test_repeat(self):
self.assertEqual(knp.repeat(x, 2).shape, (None,))
self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9))
self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3))
self.assertEqual(knp.repeat(x, 2, axis=0).shape, (None, 3))

def test_reshape(self):
x = KerasTensor([None, 3])
Expand Down Expand Up @@ -2241,6 +2242,8 @@ def test_array(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.array(x), np.array(x))
self.assertAllClose(knp.Array()(x), np.array(x))
self.assertTrue(backend.is_tensor(knp.array(x)))
self.assertTrue(backend.is_tensor(knp.Array()(x)))

def test_average(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
Expand Down
2 changes: 1 addition & 1 deletion keras_core/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def compile(
else:
self._compile_metrics = None
if jit_compile == "auto":
if model_supports_jit(self):
if not run_eagerly and model_supports_jit(self):
jit_compile = True
else:
jit_compile = False
Expand Down

0 comments on commit 8373103

Please sign in to comment.