Skip to content

Commit

Permalink
feat: Add stateful layer for idct (#26500)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-branch <[email protected]>
Co-authored-by: Sam-Armstrong <[email protected]>
  • Loading branch information
3 people committed Jul 11, 2024
1 parent 7950ca6 commit 59027b0
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 0 deletions.
68 changes: 68 additions & 0 deletions ivy/stateful/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2272,6 +2272,74 @@ def _extra_repr(self):
return s.format(**self.__dict__)


class IDct(Module):
def __init__(
self,
*,
type=2,
n=None,
axis=-1,
norm=None,
device=None,
dtype=None,
):
"""
Class for applying the Discrete Cosine Transform over mini-batch of inputs.
Parameters
----------
x
The input signal.
type
The type of the idct. Must be 1, 2, 3 or 4.
n
The length of the transform. If n is less than the input signal length,
then x is truncated, if n is larger then x is zero-padded.
axis
The axis to compute the IDCT along.
norm
The type of normalization to be applied. Must be either None or "ortho".
device
device on which to create the layer's variables 'cuda:0', 'cuda:1', 'cpu'
"""
self.type = type
self.n = n
self.axis = axis
self.norm = norm
Module.__init__(self, device=device, dtype=dtype)

def _forward(self, x):
"""
Forward pass of the layer.
Parameters
----------
x
The input array to the layer.
Returns
-------
The output array of the layer.
"""
return ivy.idct(
x,
type=self.type,
n=self.n,
axis=self.axis,
norm=self.norm,
)

def extra_repr(self):
s = "type={type}"
if self.n is not None:
s += ", n={n}"
if self.axis != -1:
s += ", axis={axis}"
if self.norm is not None:
s += ", norm={norm}"
return s.format(**self.__dict__)


# EMBEDDING #
# ----------#

Expand Down
39 changes: 39 additions & 0 deletions ivy_tests/test_ivy/test_stateful/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,45 @@ def test_dct(
)


@handle_method(
method_tree="IDct.__call__",
dtype_x_and_args=_valid_dct(),
)
def test_idct(
*,
dtype_x_and_args,
test_gradients,
on_device,
class_name,
method_name,
ground_truth_backend,
init_flags,
method_flags,
backend_fw,
):
dtype, x, type, n, axis, norm = dtype_x_and_args
helpers.test_method(
ground_truth_backend=ground_truth_backend,
backend_to_test=backend_fw,
init_flags=init_flags,
method_flags=method_flags,
init_all_as_kwargs_np={
"dtype": dtype[0],
"type": type,
"n": n,
"axis": axis,
"norm": norm,
"device": on_device,
},
method_input_dtypes=dtype,
method_all_as_kwargs_np={"x": x[0]},
class_name=class_name,
method_name=method_name,
test_gradients=test_gradients,
on_device=on_device,
)


# # depthwise conv2d
@handle_method(
method_tree="DepthwiseConv2D.__call__",
Expand Down

0 comments on commit 59027b0

Please sign in to comment.