diff --git a/keras_core/backend/jax/math.py b/keras_core/backend/jax/math.py index d0dd012ec..efafb6e91 100644 --- a/keras_core/backend/jax/math.py +++ b/keras_core/backend/jax/math.py @@ -248,3 +248,7 @@ def istft( def rsqrt(x): return jax.lax.rsqrt(x) + + +def erf(x): + return jax.lax.erf(x) diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py index 2e1c2cfcd..ce8cdb4d2 100644 --- a/keras_core/backend/numpy/math.py +++ b/keras_core/backend/numpy/math.py @@ -302,3 +302,7 @@ def istft( def rsqrt(x): return 1.0 / np.sqrt(x) + + +def erf(x): + return np.array(scipy.special.erf(x)) diff --git a/keras_core/backend/tensorflow/math.py b/keras_core/backend/tensorflow/math.py index d7139fe55..573878dbb 100644 --- a/keras_core/backend/tensorflow/math.py +++ b/keras_core/backend/tensorflow/math.py @@ -239,3 +239,7 @@ def istft( def rsqrt(x): return tf.math.rsqrt(x) + + +def erf(x): + return tf.math.erf(x) diff --git a/keras_core/backend/torch/math.py b/keras_core/backend/torch/math.py index eecdc4902..09b47e6ca 100644 --- a/keras_core/backend/torch/math.py +++ b/keras_core/backend/torch/math.py @@ -408,3 +408,8 @@ def istft( def rsqrt(x): x = convert_to_tensor(x) return torch.rsqrt(x) + + +def erf(x): + x = convert_to_tensor(x) + return torch.erf(x) diff --git a/keras_core/ops/math.py b/keras_core/ops/math.py index 413df3d33..170a602b0 100644 --- a/keras_core/ops/math.py +++ b/keras_core/ops/math.py @@ -929,3 +929,39 @@ def rsqrt(x): return Rsqrt().symbolic_call(x) x = backend.convert_to_tensor(x) return backend.math.rsqrt(x) + + +class Erf(Operation): + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x): + return backend.math.erf(x) + + +@keras_core_export("keras_core.ops.erf") +def erf(x): + """Computes the error function of x element-wise. + + Args: + x: input tensor + + Returns: + A tensor with the same type as `x`. + + Examples: + + Basic usage + >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + >>> y = Erf()(x) + + Using `float32` data type + >>> x_float32 = np.array([-3.0, -2.0], dtype=np.float32) + >>> y_float32 = Erf()(x_float32) + + Using large values + >>> x_large = np.array([1e10, -1e10]) + >>> y_large = Erf()(x_large) + """ + + return Erf()(x) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 9fdac5106..32e9689a2 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -835,3 +835,41 @@ def test_rsqrt(self): x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x)) + + def test_erf_operation_basic(self): + # Sample values for testing + sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + + # Expected output using numpy's approximation of the error function + expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( + sample_values + ) + + # Output from the erf operation in keras_core + output_from_erf_op = kmath.erf(sample_values) + + # Assert that the outputs are close + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4) + + def test_erf_operation_dtype(self): + # Test for float32 and float64 data types + for dtype in ("float32", "float64"): + sample_values = np.array( + [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype + ) + expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( + sample_values + ) + output_from_erf_op = kmath.erf(sample_values) + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4) + + def test_erf_operation_edge_cases(self): + # Test for edge cases + edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64) + expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( + edge_values + ) + output_from_edge_erf_op = kmath.erf(edge_values) + self.assertAllClose( + expected_edge_output, output_from_edge_erf_op, atol=1e-4 + )