Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some dtype fixes #935

Merged
merged 2 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions keras_core/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@


def arange(start, stop=None, step=1, dtype=None):
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype

Check warning on line 118 in keras_core/backend/jax/numpy.py

View check run for this annotation

Codecov / codecov/patch

keras_core/backend/jax/numpy.py#L118

Added line #L118 was not covered by tests
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
return jnp.arange(start, stop, step=step, dtype=dtype)


Expand Down
18 changes: 17 additions & 1 deletion keras_core/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

from keras_core.backend import config
from keras_core.backend import standardize_dtype


def add(x1, x2):
return np.add(x1, x2)
Expand Down Expand Up @@ -77,6 +80,13 @@


def arange(start, stop=None, step=None, dtype=None):
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype

Check warning on line 85 in keras_core/backend/numpy/numpy.py

View check run for this annotation

Codecov / codecov/patch

keras_core/backend/numpy/numpy.py#L85

Added line #L85 was not covered by tests
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
return np.arange(start, stop, step=step, dtype=dtype)


Expand Down Expand Up @@ -124,6 +134,7 @@


def array(x, dtype=None):
dtype = dtype or config.floatx()
return np.array(x, dtype=dtype)


Expand Down Expand Up @@ -271,6 +282,7 @@


def full(shape, fill_value, dtype=None):
dtype = dtype or config.floatx()
return np.full(shape, fill_value, dtype=dtype)


Expand Down Expand Up @@ -592,7 +604,11 @@


def sqrt(x):
return np.sqrt(x)
dtype = None
if hasattr(x, "dtype"):
if standardize_dtype(x.dtype).startswith("int"):
dtype = config.floatx()
return np.sqrt(x, dtype=dtype)


def squeeze(x, axis=None):
Expand Down
11 changes: 11 additions & 0 deletions keras_core/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp

from keras_core.backend import config
from keras_core.backend.tensorflow.core import convert_to_tensor


Expand Down Expand Up @@ -176,6 +177,13 @@
def arange(start, stop=None, step=1, dtype=None):
# tfnp.arange has trouble with dynamic Tensors in compiled function.
# tf.range does not.
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype

Check warning on line 182 in keras_core/backend/tensorflow/numpy.py

View check run for this annotation

Codecov / codecov/patch

keras_core/backend/tensorflow/numpy.py#L182

Added line #L182 was not covered by tests
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
return tf.range(start, stop, delta=step, dtype=dtype)


Expand Down Expand Up @@ -749,6 +757,9 @@


def sqrt(x):
x = convert_to_tensor(x)
if tf.as_dtype(x.dtype).is_integer:
x = tf.cast(x, dtype=config.floatx())
return tfnp.sqrt(x)


Expand Down
10 changes: 9 additions & 1 deletion keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch

from keras_core.backend import config
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import convert_to_tensor
from keras_core.backend.torch.core import get_device
Expand Down Expand Up @@ -91,7 +92,7 @@

def zeros_like(x, dtype=None):
x = convert_to_tensor(x)
dtype = to_torch_dtype(dtype)
dtype = to_torch_dtype(dtype or x.dtype)
return torch.zeros_like(x, dtype=dtype)


Expand Down Expand Up @@ -160,6 +161,13 @@


def arange(start, stop=None, step=1, dtype=None):
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype

Check warning on line 166 in keras_core/backend/torch/numpy.py

View check run for this annotation

Codecov / codecov/patch

keras_core/backend/torch/numpy.py#L166

Added line #L166 was not covered by tests
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
dtype = to_torch_dtype(dtype)
if stop is None:
return torch.arange(end=start, dtype=dtype, device=get_device())
Expand Down
36 changes: 33 additions & 3 deletions keras_core/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3571,9 +3571,37 @@ def test_split(self):
self.assertEqual(len(knp.Split(2)(x)), 2)

def test_sqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]])
self.assertAllClose(knp.sqrt(x), np.sqrt(x))
self.assertAllClose(knp.Sqrt()(x), np.sqrt(x))
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
ref_y = np.sqrt(x)
y = knp.sqrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)
y = knp.Sqrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)

@pytest.mark.skipif(
backend.backend() == "jax", reason="JAX does not support float64."
)
def test_sqrt_float64(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float64")
ref_y = np.sqrt(x)
y = knp.sqrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float64")
self.assertAllClose(y, ref_y)
y = knp.Sqrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float64")
self.assertAllClose(y, ref_y)

def test_sqrt_int32(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="int32")
ref_y = np.sqrt(x)
y = knp.sqrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)
y = knp.Sqrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)

def test_stack(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
Expand Down Expand Up @@ -3704,6 +3732,8 @@ def test_arange(self):
self.assertAllClose(knp.Arange()(3, 7), np.arange(3, 7))
self.assertAllClose(knp.Arange()(3, 7, 2), np.arange(3, 7, 2))

self.assertEqual(standardize_dtype(knp.arange(3).dtype), "int32")

def test_full(self):
self.assertAllClose(knp.full([2, 3], 0), np.full([2, 3], 0))
self.assertAllClose(knp.full([2, 3], 0.1), np.full([2, 3], 0.1))
Expand Down