Skip to content

Commit

Permalink
Change logic for testing for complex input
Browse files Browse the repository at this point in the history
Suggested from code review. Uses paddle's `is_complex()` function

Co-authored-by: Mahmoud Ashraf <[email protected]>
  • Loading branch information
jshepherd01 and MahmoudAshraf97 authored Aug 16, 2023
1 parent 36cefa6 commit 78f158b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion ivy/functional/backends/paddle/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def gelu(
approximate: bool = False,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
sqrt_2_over_pi = 0.7978845608
# the other magic number comes directly from the formula in
# https://doi.org/10.48550/arXiv.1606.08415
Expand Down
26 changes: 13 additions & 13 deletions ivy/functional/backends/paddle/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def asin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
]:
ret_dtype = x.dtype
return paddle.asin(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
asinh_iz = paddle_backend.asinh(paddle.complex(-x.imag(), x.real()))
return paddle.complex(asinh_iz.imag(), -asinh_iz.real())
return paddle.asin(x)
Expand All @@ -231,7 +231,7 @@ def asinh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle
]:
ret_dtype = x.dtype
return paddle.asinh(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
# From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L276 # noqa
s1 = paddle_backend.sqrt(paddle.complex(1 + x.imag(), -x.real()))
s2 = paddle_backend.sqrt(paddle.complex(1 - x.imag(), x.real()))
Expand Down Expand Up @@ -304,7 +304,7 @@ def cosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
]:
ret_dtype = x.dtype
return paddle.cosh(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
re = x.real()
im = x.imag()
return paddle.complex(
Expand Down Expand Up @@ -434,7 +434,7 @@ def cos(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T
]:
ret_dtype = x.dtype
return paddle.cos(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
re = x.real()
im = x.imag()
return paddle.complex(
Expand Down Expand Up @@ -537,7 +537,7 @@ def acos(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
paddle.float16,
]:
return paddle.acos(x.astype("float32")).astype(x.dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
# From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L178 # noqa
s1 = paddle_backend.sqrt(1 - x)
s2 = paddle_backend.sqrt(1 + x)
Expand Down Expand Up @@ -617,7 +617,7 @@ def acosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle
paddle.float16,
]:
return paddle.acosh(x.astype("float32")).astype(x.dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
# From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L221 # noqa
s1 = paddle_backend.sqrt(paddle.complex(x.real() - 1, x.imag()))
s2 = paddle_backend.sqrt(paddle.complex(x.real() + 1, x.imag()))
Expand All @@ -642,7 +642,7 @@ def sin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T
paddle.float16,
]:
return paddle.sin(x.astype("float32")).astype(x.dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
re = x.real()
im = x.imag()
return paddle.complex(
Expand Down Expand Up @@ -687,7 +687,7 @@ def tanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
paddle.float16,
]:
return paddle.tanh(x.astype("float32")).astype(x.dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
tanh_a = paddle.tanh(paddle.real(x))
tan_b = paddle.tan(paddle.imag(x))
return (tanh_a + 1j * tan_b) / (1 + 1j * (tanh_a * tan_b))
Expand Down Expand Up @@ -733,7 +733,7 @@ def sinh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
]:
ret_dtype = x.dtype
return paddle.sinh(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
re = x.real()
im = x.imag()
return paddle.complex(
Expand All @@ -757,7 +757,7 @@ def square(
) -> paddle.Tensor:
if x.dtype in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
return paddle.square(x)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
return paddle.complex(
paddle.square(paddle.real(x)) - paddle.square(paddle.imag(x)),
2.0 * paddle.real(x) * paddle.imag(x),
Expand Down Expand Up @@ -979,7 +979,7 @@ def tan(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T
]:
ret_dtype = x.dtype
return paddle.tan(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
tanh_ix = paddle_backend.tanh(paddle.complex(-x.imag(), x.real()))
return paddle.complex(tanh_ix.imag(), -tanh_ix.real())
return paddle.tan(x)
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def log(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T
def exp(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
if x.dtype in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
return paddle.exp(x)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
return paddle.multiply(
paddle.exp(x.real()),
paddle.complex(paddle.cos(x.imag()), paddle.sin(x.imag())),
Expand Down Expand Up @@ -1118,7 +1118,7 @@ def atanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle
]:
ret_dtype = x.dtype
return paddle.atanh(x.astype("float32")).astype(ret_dtype)
if x.dtype in [paddle.complex64, paddle.complex128]:
if paddle.is_complex(x):
return 0.5 * (paddle_backend.log(1 + x) - paddle_backend.log(1 - x))
return paddle.atanh(x)

Expand Down

0 comments on commit 78f158b

Please sign in to comment.