Skip to content

Commit

Permalink
feat: implement inverse fast fourier transform (irfftn) function for …
Browse files Browse the repository at this point in the history
…paddle frontend (#23526)
  • Loading branch information
xingshuodong authored Sep 19, 2023
1 parent a4c2ed4 commit 1bd6928
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
52 changes: 52 additions & 0 deletions ivy/functional/frontends/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,58 @@ def irfft(x, n=None, axis=-1.0, norm="backward", name=None):
return time_domain


@with_supported_dtypes(
{"2.5.1 and below": ("complex64", "complex128")},
"paddle",
)
@to_ivy_arrays_and_back
def irfftn(x, s=None, axes=None, norm="backward", name=None):
x = ivy.array(x)

if axes is None:
axes = list(range(len(x.shape)))

include_last_axis = len(x.shape) - 1 in axes

if s is None:
s = [
x.shape[axis] if axis != (len(x.shape) - 1) else 2 * (x.shape[axis] - 1)
for axis in axes
]

real_result = x
remaining_axes = [axis for axis in axes if axis != (len(x.shape) - 1)]

if remaining_axes:
real_result = ivy.ifftn(
x,
s=[s[axes.index(axis)] for axis in remaining_axes],
axes=remaining_axes,
norm=norm,
)

if include_last_axis:
axis = len(x.shape) - 1
size = s[axes.index(axis)]
freq_domain = ivy.moveaxis(real_result, axis, -1)
slices = [slice(None)] * ivy.get_num_dims(freq_domain)
slices[-1] = slice(0, size // 2 + 1)
pos_freq_terms = freq_domain[tuple(slices)]
slices[-1] = slice(1, -1)
neg_freq_terms = ivy.conj(pos_freq_terms[tuple(slices)][..., ::-1])
combined_freq_terms = ivy.concat((pos_freq_terms, neg_freq_terms), axis=-1)
real_result = ivy.ifftn(combined_freq_terms, s=[size], axes=[-1], norm=norm)
real_result = ivy.moveaxis(real_result, -1, axis)

if ivy.is_complex_dtype(x.dtype):
output_dtype = "float32" if x.dtype == "complex64" else "float64"
else:
output_dtype = "float32"

result_t = ivy.astype(real_result, output_dtype)
return result_t


@to_ivy_arrays_and_back
def rfftfreq(n, d=1.0, dtype=None, name=None):
dtype = ivy.default_dtype()
Expand Down
37 changes: 37 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,43 @@ def test_paddle_irfft(
)


@handle_frontend_test(
fn_tree="paddle.fft.irfftn",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("complex"),
min_value=-10,
max_value=10,
min_num_dims=1,
max_num_dims=5,
min_dim_size=2,
max_dim_size=5,
valid_axis=True,
force_int_axis=True,
),
norm=st.sampled_from(["backward", "ortho", "forward"]),
)
def test_paddle_irfftn(
dtype_x_axis,
norm,
frontend,
test_flags,
fn_tree,
backend_fw,
):
input_dtypes, x, axis = dtype_x_axis
helpers.test_frontend_function(
input_dtypes=input_dtypes,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
x=x[0],
s=None,
axes=None,
norm=norm,
)


@handle_frontend_test(
fn_tree="paddle.fft.rfftfreq",
n=st.integers(min_value=1, max_value=1000),
Expand Down

0 comments on commit 1bd6928

Please sign in to comment.