From 8a1f807990848a9ed9cdc1f869444f3cac47db8e Mon Sep 17 00:00:00 2001 From: Jay Choy <91728831+ZJay07@users.noreply.github.com> Date: Sun, 24 Mar 2024 12:07:27 +0000 Subject: [PATCH 1/4] fixed jax.sinc for tensorflow frontend --- .../backends/tensorflow/experimental/elementwise.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py index 977e0d0584c87..82eace92d0b2f 100644 --- a/ivy/functional/backends/tensorflow/experimental/elementwise.py +++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py @@ -66,6 +66,10 @@ def lgamma( return tf.math.lgamma(x) +@with_unsupported_dtypes( + {"2.15.0 and below": ("bfloat16",)}, + backend_version, +) def sinc( x: Union[tf.Tensor, tf.Variable], /, From f480a017e1867a93396ecc5d2ca8416cfcba8c0a Mon Sep 17 00:00:00 2001 From: Jay Choy <91728831+ZJay07@users.noreply.github.com> Date: Mon, 1 Apr 2024 19:57:14 +0000 Subject: [PATCH 2/4] proper fix --- ivy/functional/backends/tensorflow/experimental/elementwise.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py index 82eace92d0b2f..c5589e33e8d21 100644 --- a/ivy/functional/backends/tensorflow/experimental/elementwise.py +++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py @@ -76,8 +76,7 @@ def sinc( *, out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: - x = ivy.pi * x - return tf.cast(tf.where(x == 0, 1, tf.math.sin(x) / x), x.dtype) + return tf.experimental.numpy.sinc(x) @with_supported_dtypes( From 895695967ad2b1295af17ceaf86c998b40663c05 Mon Sep 17 00:00:00 2001 From: Jay Choy <91728831+ZJay07@users.noreply.github.com> Date: Mon, 1 Apr 2024 20:05:22 +0000 Subject: [PATCH 3/4] cast back to same dtype as x --- .../backends/tensorflow/experimental/elementwise.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py index c5589e33e8d21..4b803b372a1a7 100644 --- a/ivy/functional/backends/tensorflow/experimental/elementwise.py +++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py @@ -66,17 +66,13 @@ def lgamma( return tf.math.lgamma(x) -@with_unsupported_dtypes( - {"2.15.0 and below": ("bfloat16",)}, - backend_version, -) def sinc( x: Union[tf.Tensor, tf.Variable], /, *, out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: - return tf.experimental.numpy.sinc(x) + return tf.cast(tf.experimental.numpy.sinc(x), x.dtype) @with_supported_dtypes( From 9ea2cd20cd7e354c2941dc0f94f87cb6d2978af8 Mon Sep 17 00:00:00 2001 From: Jay Choy <91728831+ZJay07@users.noreply.github.com> Date: Tue, 2 Apr 2024 17:25:17 +0000 Subject: [PATCH 4/4] fixed additional test --- ivy/functional/backends/paddle/experimental/elementwise.py | 2 +- .../test_ivy/test_frontends/test_torch/test_pointwise_ops.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/paddle/experimental/elementwise.py b/ivy/functional/backends/paddle/experimental/elementwise.py index c26898376ea69..e378ed820779e 100644 --- a/ivy/functional/backends/paddle/experimental/elementwise.py +++ b/ivy/functional/backends/paddle/experimental/elementwise.py @@ -89,7 +89,7 @@ def fmax( @with_unsupported_device_and_dtypes( - {"2.6.0 and below": {"cpu": ("float16",)}}, backend_version + {"2.6.0 and below": {"cpu": ("float16", "bfloat16")}}, backend_version ) def sinc(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: y = ivy.pi * paddle.where(x == 0, paddle.to_tensor(1.0e-20, dtype=x.dtype), x) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py index 52c15b10bd760..72d820202da6b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py @@ -2716,6 +2716,8 @@ def test_torch_sinc( fn_tree=fn_tree, on_device=on_device, input=x[0], + atol=1e-02, + rtol=1e-02, )