Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added erf op to math.py #908

Closed
wants to merge 14 commits into from
Closed
4 changes: 4 additions & 0 deletions keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,7 @@ def istft(

def rsqrt(x):
return jax.lax.rsqrt(x)


def erf(x):
return jnp.erf(x)
4 changes: 4 additions & 0 deletions keras_core/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,7 @@ def istft(

def rsqrt(x):
return 1.0 / np.sqrt(x)


def erf(x):
return scipy.special.erf(x)
4 changes: 4 additions & 0 deletions keras_core/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,7 @@ def istft(

def rsqrt(x):
return tf.math.rsqrt(x)


def erf(x):
return tf.math.erf(x)
6 changes: 6 additions & 0 deletions keras_core/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,9 @@ def istft(
def rsqrt(x):
x = convert_to_tensor(x)
return torch.rsqrt(x)


def erf(x):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just do x = convert_to_tensor(x) unconditionally

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return torch.erf(x)
38 changes: 38 additions & 0 deletions keras_core/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,3 +929,41 @@ def rsqrt(x):
return Rsqrt().symbolic_call(x)
x = backend.convert_to_tensor(x)
return backend.math.rsqrt(x)


class Erf(Operation):
"""Computes the error function of x element-wise.

Args:
input_tensor: A tensor of type `float32` or `float64`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can have more types, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edited the comments based on that


Returns:
A tensor of the same shape and type as `input_tensor`.

Examples:

# Basic usage
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're not printing any outputs, just use a fenced code block for the code example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

>>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])
>>> y = Erf()(x)
# Using `float32` data type
>>> x_float32 = np.array([-3.0, -2.0], dtype=np.float32)
>>> y_float32 = Erf()(x_float32)
# Using large values
>>> x_large = np.array([1e10, -1e10])
>>> y_large = Erf()(x_large)
"""

def __init__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed it

super().__init__()

def compute_output_spec(self, input_tensor):
return KerasTensor(shape=input_tensor.shape, dtype=input_tensor.dtype)

def call(self, input_tensor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just x

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with x

return backend.erf(input_tensor)


@keras_core_export("keras_core.ops.erf")
def erf(x):
"""Functional interface to the `Erf` operation."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the docstring should be, not the op above, since this is the public symbol.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relocated it!

return Erf()(x)
38 changes: 38 additions & 0 deletions keras_core/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,3 +835,41 @@ def test_rsqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x))
self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x))

def test_erf_operation_basic(self):
# Sample values for testing
sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])

# Expected output using numpy's approximation of the error function
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
sample_values
)

# Output from the erf operation in keras_core
output_from_erf_op = kmath.erf(sample_values).numpy()

# Assert that the outputs are close
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5)

def test_erf_operation_dtype(self):
# Test for float32 and float64 data types
for dtype in [np.float32, np.float64]:
sqali marked this conversation as resolved.
Show resolved Hide resolved
sample_values = np.array(
[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype
)
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
sample_values
)
output_from_erf_op = kmath.erf(sample_values).numpy()
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5)

def test_erf_operation_edge_cases(self):
# Test for edge cases
edge_values = np.array([1e10, -1e10, 1e-10, -1e-10], dtype=np.float64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your test values are too large. Try 1e5. This the source of the large discrepancy IMO.

Copy link
Contributor Author

@sqali sqali Sep 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have implemented the changes, but I can see from the tests that it is failing for the below array examples. I wonder if there is anything wrong in the implementation function itself.

  • x: array([ 1.128379e+00, -1.128379e+00, 1.273240e-05, -1.273240e-05])
  • x: array([-1.128354, -1.123101, -0.950886, 0. , 0.950886, 1.123101, 1.128354])

image

expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
edge_values
)
output_from_edge_erf_op = kmath.erf(edge_values).numpy()
self.assertAllClose(
expected_edge_output, output_from_edge_erf_op, atol=1e-5
)
Loading