From 6485433201e8a935f462aca191bc37dd7cf00c30 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Sun, 17 Sep 2023 17:50:24 +0530 Subject: [PATCH 01/10] added erf op to math.py --- keras_core/backend/tensorflow/math.py | 4 +++ keras_core/ops/math.py | 35 +++++++++++++++++++++++++++ keras_core/ops/math_test.py | 31 ++++++++++++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/keras_core/backend/tensorflow/math.py b/keras_core/backend/tensorflow/math.py index d7139fe55..3e503c9ec 100644 --- a/keras_core/backend/tensorflow/math.py +++ b/keras_core/backend/tensorflow/math.py @@ -239,3 +239,7 @@ def istft( def rsqrt(x): return tf.math.rsqrt(x) + + +def erf(x): + return tf.math.erf(x) \ No newline at end of file diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index 2e1f28ab0..9962c0e20 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -929,3 +929,38 @@ def rsqrt(x): return Rsqrt().symbolic_call(x) x = backend.convert_to_tensor(x) return backend.math.rsqrt(x) + + +class Erf(Operation): + """Computes the error function of x element-wise. + + Args: + input_tensor: A tensor of type `float32` or `float64`. + + Returns: + A tensor of the same shape and type as `input_tensor`, containing the error function values. + + Examples: + + # Basic usage + >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + >>> y = Erf()(x) + # Using `float32` data type + >>> x_float32 = np.array([-3.0, -2.0], dtype=np.float32) + >>> y_float32 = Erf()(x_float32) + # Using large values + >>> x_large = np.array([1e10, -1e10]) + >>> y_large = Erf()(x_large) + """ + def __init__(self): + super().__init__() + + def compute_output_spec(self, input_tensor): + return KerasTensor(shape=input_tensor.shape, dtype=input_tensor.dtype) + + def call(self, input_tensor): + return backend.erf(input_tensor) + +def erf(x): + """Functional interface to the `Erf` operation.""" + return Erf()(x) \ No newline at end of file diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 9fdac5106..7f09dc397 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -835,3 +835,34 @@ def test_rsqrt(self): x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x)) + + +class MathOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): + + def test_erf_operation_basic(self): + # Sample values for testing + sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + + # Expected output using numpy's approximation of the error function + expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) + + # Output from the erf operation in keras_core + output_from_erf_op = kmath.erf(sample_values).numpy() + + # Assert that the outputs are close + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) + + def test_erf_operation_dtype(self): + # Test for float32 and float64 data types + for dtype in [np.float32, np.float64]: + sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype) + expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) + output_from_erf_op = kmath.erf(sample_values).numpy() + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) + + def test_erf_operation_edge_cases(self): + # Test for edge cases + edge_values = np.array([1e10, -1e10, 1e-10, -1e-10], dtype=np.float64) + expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(edge_values) + output_from_edge_erf_op = kmath.erf(edge_values).numpy() + self.assertAllClose(expected_edge_output, output_from_edge_erf_op, atol=1e-5) \ No newline at end of file From 61c9690a1b91d5e70927b2b9862cd5525b5fd1b6 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Mon, 18 Sep 2023 19:29:27 +0530 Subject: [PATCH 02/10] code reformatting done, introduced erf function in jax, torch and numpy --- keras_core/backend/jax/math.py | 4 ++++ keras_core/backend/numpy/math.py | 4 ++++ keras_core/backend/tensorflow/math.py | 2 +- keras_core/backend/torch/math.py | 6 ++++++ keras_core/ops/math.py | 6 ++++-- keras_core/ops/math_test.py | 21 +++++++++++++++------ 6 files changed, 34 insertions(+), 9 deletions(-) diff --git a/keras_core/backend/jax/math.py b/keras_core/backend/jax/math.py index d0dd012ec..7ff0a729e 100644 --- a/keras_core/backend/jax/math.py +++ b/keras_core/backend/jax/math.py @@ -248,3 +248,7 @@ def istft( def rsqrt(x): return jax.lax.rsqrt(x) + + +def erf(x): + return jnp.erf(x) diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py index 2e1c2cfcd..d7e6fd3c4 100644 --- a/keras_core/backend/numpy/math.py +++ b/keras_core/backend/numpy/math.py @@ -302,3 +302,7 @@ def istft( def rsqrt(x): return 1.0 / np.sqrt(x) + + +def erf(x): + return scipy.special.erf(x) diff --git a/keras_core/backend/tensorflow/math.py b/keras_core/backend/tensorflow/math.py index 3e503c9ec..573878dbb 100644 --- a/keras_core/backend/tensorflow/math.py +++ b/keras_core/backend/tensorflow/math.py @@ -242,4 +242,4 @@ def rsqrt(x): def erf(x): - return tf.math.erf(x) \ No newline at end of file + return tf.math.erf(x) diff --git a/keras_core/backend/torch/math.py b/keras_core/backend/torch/math.py index eecdc4902..e9d9ec6f0 100644 --- a/keras_core/backend/torch/math.py +++ b/keras_core/backend/torch/math.py @@ -408,3 +408,9 @@ def istft( def rsqrt(x): x = convert_to_tensor(x) return torch.rsqrt(x) + + +def erf(x): + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + return torch.erf(x) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index 9962c0e20..323e58112 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -936,7 +936,7 @@ class Erf(Operation): Args: input_tensor: A tensor of type `float32` or `float64`. - + Returns: A tensor of the same shape and type as `input_tensor`, containing the error function values. @@ -952,6 +952,7 @@ class Erf(Operation): >>> x_large = np.array([1e10, -1e10]) >>> y_large = Erf()(x_large) """ + def __init__(self): super().__init__() @@ -961,6 +962,7 @@ def compute_output_spec(self, input_tensor): def call(self, input_tensor): return backend.erf(input_tensor) + def erf(x): """Functional interface to the `Erf` operation.""" - return Erf()(x) \ No newline at end of file + return Erf()(x) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 7f09dc397..3b5660b1e 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -838,13 +838,14 @@ def test_rsqrt(self): class MathOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): - def test_erf_operation_basic(self): # Sample values for testing sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) # Expected output using numpy's approximation of the error function - expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) + expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( + sample_values + ) # Output from the erf operation in keras_core output_from_erf_op = kmath.erf(sample_values).numpy() @@ -855,14 +856,22 @@ def test_erf_operation_basic(self): def test_erf_operation_dtype(self): # Test for float32 and float64 data types for dtype in [np.float32, np.float64]: - sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype) - expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) + sample_values = np.array( + [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype + ) + expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( + sample_values + ) output_from_erf_op = kmath.erf(sample_values).numpy() self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) def test_erf_operation_edge_cases(self): # Test for edge cases edge_values = np.array([1e10, -1e10, 1e-10, -1e-10], dtype=np.float64) - expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(edge_values) + expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( + edge_values + ) output_from_edge_erf_op = kmath.erf(edge_values).numpy() - self.assertAllClose(expected_edge_output, output_from_edge_erf_op, atol=1e-5) \ No newline at end of file + self.assertAllClose( + expected_edge_output, output_from_edge_erf_op, atol=1e-5 + ) From 79601064739edc6f110d2e590357fd4d2eaebba9 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Mon, 18 Sep 2023 19:40:09 +0530 Subject: [PATCH 03/10] shortened comment lines, renamed the new function --- keras_core/ops/math.py | 2 +- keras_core/ops/math_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index 323e58112..84fa1fb0f 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -938,7 +938,7 @@ class Erf(Operation): input_tensor: A tensor of type `float32` or `float64`. Returns: - A tensor of the same shape and type as `input_tensor`, containing the error function values. + A tensor of the same shape and type as `input_tensor`. Examples: diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 3b5660b1e..f5cb49d4d 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -837,7 +837,7 @@ def test_rsqrt(self): self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x)) -class MathOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): +class ErfFunctionTests(testing.TestCase, parameterized.TestCase): def test_erf_operation_basic(self): # Sample values for testing sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) From d9f99f27e3726261d8187ba767ee8b9fb9eb1ff2 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Mon, 18 Sep 2023 20:06:28 +0530 Subject: [PATCH 04/10] added decorator above the erf function, added tests under the existing MathOpsCorrectnessTest class --- keras_core/ops/math.py | 1 + keras_core/ops/math_test.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index 84fa1fb0f..c932cf129 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -963,6 +963,7 @@ def call(self, input_tensor): return backend.erf(input_tensor) +@keras_core_export("keras_core.ops.erf") def erf(x): """Functional interface to the `Erf` operation.""" return Erf()(x) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index f5cb49d4d..1eb33aee2 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -836,8 +836,6 @@ def test_rsqrt(self): self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x)) - -class ErfFunctionTests(testing.TestCase, parameterized.TestCase): def test_erf_operation_basic(self): # Sample values for testing sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) From 8d4b8ebd24a60876a79d5ae22296885516acbcec Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Tue, 19 Sep 2023 23:38:34 +0530 Subject: [PATCH 05/10] changes done to erf functions as per suggestion --- keras_core/backend/torch/math.py | 3 +-- keras_core/ops/math.py | 26 +++++++++++--------------- keras_core/ops/math_test.py | 2 +- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/keras_core/backend/torch/math.py b/keras_core/backend/torch/math.py index e9d9ec6f0..09b47e6ca 100644 --- a/keras_core/backend/torch/math.py +++ b/keras_core/backend/torch/math.py @@ -411,6 +411,5 @@ def rsqrt(x): def erf(x): - if not isinstance(x, torch.Tensor): - x = torch.tensor(x) + x = convert_to_tensor(x) return torch.erf(x) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index c932cf129..569b05589 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -932,13 +932,22 @@ def rsqrt(x): class Erf(Operation): + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x): + return backend.erf(x) + + +@keras_core_export("keras_core.ops.erf") +def erf(x): """Computes the error function of x element-wise. Args: - input_tensor: A tensor of type `float32` or `float64`. + x: input tensor Returns: - A tensor of the same shape and type as `input_tensor`. + A tensor with the same type as `x`. Examples: @@ -953,17 +962,4 @@ class Erf(Operation): >>> y_large = Erf()(x_large) """ - def __init__(self): - super().__init__() - - def compute_output_spec(self, input_tensor): - return KerasTensor(shape=input_tensor.shape, dtype=input_tensor.dtype) - - def call(self, input_tensor): - return backend.erf(input_tensor) - - -@keras_core_export("keras_core.ops.erf") -def erf(x): - """Functional interface to the `Erf` operation.""" return Erf()(x) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 1eb33aee2..14324cbe3 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -853,7 +853,7 @@ def test_erf_operation_basic(self): def test_erf_operation_dtype(self): # Test for float32 and float64 data types - for dtype in [np.float32, np.float64]: + for dtype in ("float32", "float64"): sample_values = np.array( [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype ) From 6b52b8afb6f0106abba9fe0b239141f092f11e6f Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Thu, 21 Sep 2023 13:42:52 +0530 Subject: [PATCH 06/10] backend.math added to erf function, fenced code block used fo --- keras_core/ops/math.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index 96d2d80d1..dfe3e89ef 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -936,7 +936,7 @@ def compute_output_spec(self, x): return KerasTensor(shape=x.shape, dtype=x.dtype) def call(self, x): - return backend.erf(x) + return backend.math.erf(x) @keras_core_export("keras_core.ops.erf") @@ -951,13 +951,15 @@ def erf(x): Examples: - # Basic usage + Basic usage >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) >>> y = Erf()(x) - # Using `float32` data type + + Using `float32` data type >>> x_float32 = np.array([-3.0, -2.0], dtype=np.float32) >>> y_float32 = Erf()(x_float32) - # Using large values + + Using large values >>> x_large = np.array([1e10, -1e10]) >>> y_large = Erf()(x_large) """ From 7f764597f24f09bd9b89b0c586ced1ecb9955cad Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Thu, 21 Sep 2023 13:44:27 +0530 Subject: [PATCH 07/10] code reformatting applied --- keras_core/ops/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index dfe3e89ef..170a602b0 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -958,7 +958,7 @@ def erf(x): Using `float32` data type >>> x_float32 = np.array([-3.0, -2.0], dtype=np.float32) >>> y_float32 = Erf()(x_float32) - + Using large values >>> x_large = np.array([1e10, -1e10]) >>> y_large = Erf()(x_large) From 5efabf539a4eac8288e563c1a3c00854b8d5c8fb Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Thu, 21 Sep 2023 16:35:27 +0530 Subject: [PATCH 08/10] corrected jax expression, numpy function --- keras_core/backend/jax/math.py | 2 +- keras_core/backend/numpy/math.py | 2 +- keras_core/ops/math_test.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_core/backend/jax/math.py b/keras_core/backend/jax/math.py index 7ff0a729e..efafb6e91 100644 --- a/keras_core/backend/jax/math.py +++ b/keras_core/backend/jax/math.py @@ -251,4 +251,4 @@ def rsqrt(x): def erf(x): - return jnp.erf(x) + return jax.lax.erf(x) diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py index d7e6fd3c4..ce8cdb4d2 100644 --- a/keras_core/backend/numpy/math.py +++ b/keras_core/backend/numpy/math.py @@ -305,4 +305,4 @@ def rsqrt(x): def erf(x): - return scipy.special.erf(x) + return np.array(scipy.special.erf(x)) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 14324cbe3..1b5ffcb2c 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -846,7 +846,7 @@ def test_erf_operation_basic(self): ) # Output from the erf operation in keras_core - output_from_erf_op = kmath.erf(sample_values).numpy() + output_from_erf_op = kmath.erf(sample_values) # Assert that the outputs are close self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) @@ -860,7 +860,7 @@ def test_erf_operation_dtype(self): expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( sample_values ) - output_from_erf_op = kmath.erf(sample_values).numpy() + output_from_erf_op = kmath.erf(sample_values) self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) def test_erf_operation_edge_cases(self): @@ -869,7 +869,7 @@ def test_erf_operation_edge_cases(self): expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( edge_values ) - output_from_edge_erf_op = kmath.erf(edge_values).numpy() + output_from_edge_erf_op = kmath.erf(edge_values) self.assertAllClose( expected_edge_output, output_from_edge_erf_op, atol=1e-5 ) From 3f199c343407f655f5e4e2a08589b4443aca6f75 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Fri, 22 Sep 2023 09:08:57 +0530 Subject: [PATCH 09/10] Update math_test.py lowered the precision tolerance to 1e-4 for erf function --- keras_core/ops/math_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 1b5ffcb2c..96cc7022a 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -849,7 +849,7 @@ def test_erf_operation_basic(self): output_from_erf_op = kmath.erf(sample_values) # Assert that the outputs are close - self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4) def test_erf_operation_dtype(self): # Test for float32 and float64 data types @@ -861,7 +861,7 @@ def test_erf_operation_dtype(self): sample_values ) output_from_erf_op = kmath.erf(sample_values) - self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4) def test_erf_operation_edge_cases(self): # Test for edge cases @@ -871,5 +871,5 @@ def test_erf_operation_edge_cases(self): ) output_from_edge_erf_op = kmath.erf(edge_values) self.assertAllClose( - expected_edge_output, output_from_edge_erf_op, atol=1e-5 + expected_edge_output, output_from_edge_erf_op, atol=1e-4 ) From 0307b744f0c318614068c950da02ec2ede4e5d24 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Fri, 22 Sep 2023 22:42:20 +0530 Subject: [PATCH 10/10] Update math_test.py lowered the test values to 1e5 --- keras_core/ops/math_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 96cc7022a..32e9689a2 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -865,7 +865,7 @@ def test_erf_operation_dtype(self): def test_erf_operation_edge_cases(self): # Test for edge cases - edge_values = np.array([1e10, -1e10, 1e-10, -1e-10], dtype=np.float64) + edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64) expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( edge_values )