Skip to content

Commit

Permalink
Update Torch ops.array (#375)
Browse files Browse the repository at this point in the history
* Update Torch ops.array

This was previously inconsistent with Jax and TF, where ops.array produces an array/tensor of the native backend type. Instead this produced a NumPy array

* Add ops test

* Fix metrics tests

* Update torch np.tile implementation
  • Loading branch information
ianstenbit authored Jun 21, 2023
1 parent 4c435b2 commit 9c575e6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
7 changes: 5 additions & 2 deletions keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import convert_to_tensor
from keras_core.backend.torch.core import get_device
from keras_core.backend.torch.core import is_tensor
from keras_core.backend.torch.core import to_torch_dtype

TORCH_INT_TYPES = (
Expand Down Expand Up @@ -187,9 +188,9 @@ def argsort(x, axis=-1):

def array(x, dtype=None):
dtype = to_torch_dtype(dtype)
if not isinstance(x, torch.Tensor):
if isinstance(x, torch.Tensor):
return x
return x.numpy()
return torch.tensor(x, dtype=dtype)


def average(x, axis=None, weights=None):
Expand Down Expand Up @@ -754,6 +755,8 @@ def round(x, decimals=0):


def tile(x, repeats):
if is_tensor(repeats):
repeats = tuple(repeats.int().numpy())
x = convert_to_tensor(x)
return torch.tile(x, dims=repeats)

Expand Down
2 changes: 1 addition & 1 deletion keras_core/metrics/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def update_confusion_matrix_variables(
data_tiles = [num_thresholds, 1]

thresh_tiled = ops.tile(
ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
)

# Tile the predictions for every threshold.
Expand Down
2 changes: 2 additions & 0 deletions keras_core/operations/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,6 +2242,8 @@ def test_array(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.array(x), np.array(x))
self.assertAllClose(knp.Array()(x), np.array(x))
self.assertTrue(backend.is_tensor(knp.array(x)))
self.assertTrue(backend.is_tensor(knp.Array()(x)))

def test_average(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
Expand Down

0 comments on commit 9c575e6

Please sign in to comment.