From 74c13af03dc912f2ffd9391a5cbdf6230b42674b Mon Sep 17 00:00:00 2001 From: Kazeem Hakeem Date: Tue, 15 Aug 2023 17:14:39 +0100 Subject: [PATCH] BlackmanWindow #19480 (#19882) Co-authored-by: ivy-branch Co-authored-by: Samsam Lee <106169847+jieunboy0516@users.noreply.github.com> --- .../array/experimental/creation.py | 46 +++++++++ .../container/experimental/creation.py | 93 +++++++++++++++++++ .../backends/jax/experimental/creation.py | 20 ++++ .../backends/mxnet/experimental/creation.py | 11 +++ .../backends/numpy/experimental/creation.py | 24 +++++ .../backends/paddle/experimental/creation.py | 22 +++++ .../tensorflow/experimental/creation.py | 20 ++++ .../backends/torch/experimental/creation.py | 20 ++++ ivy/functional/ivy/experimental/creation.py | 49 ++++++++++ .../test_core/test_creation.py | 29 ++++++ 10 files changed, 334 insertions(+) diff --git a/ivy/data_classes/array/experimental/creation.py b/ivy/data_classes/array/experimental/creation.py index 6099906f9c4a0..1c9bfb6502429 100644 --- a/ivy/data_classes/array/experimental/creation.py +++ b/ivy/data_classes/array/experimental/creation.py @@ -123,3 +123,49 @@ def unsorted_segment_sum( `segment_ids` equals to segment ID. """ return ivy.unsorted_segment_sum(self._data, segment_ids, num_segments) + + def blackman_window( + self: ivy.Array, + /, + *, + periodic: bool = True, + dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None, + device: Optional[Union[ivy.Device, ivy.NativeDevice]] = None, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of ivy.blackman_window. This method simply wraps the + function, and so the docstring for ivy.blackman_window also applies to this method with + minimal changes. + Parameters + ---------- + self + int. + periodic + If True, returns a window to be used as periodic function. + If False, return a symmetric window. + Default: ``True``. + dtype + output array data type. If ``dtype`` is ``None``, the output array data type + must be inferred from ``self``. Default: ``None``. + device + device on which to place the created array. If ``device`` is ``None``, the + output array device must be inferred from ``self``. Default: ``None``. + out + optional output array, for writing the result to. It must have a shape that + the inputs broadcast to. + Returns + ------- + ret + The array containing the window. + Examples + -------- + >>> ivy.blackman_window(4, periodic = True) + ivy.array([-1.38777878e-17, 3.40000000e-01, 1.00000000e+00, 3.40000000e-01]) + >>> ivy.blackman_window(7, periodic = False) + ivy.array([-1.38777878e-17, 1.30000000e-01, 6.30000000e-01, 1.00000000e+00, + 6.30000000e-01, 1.30000000e-01, -1.38777878e-17]) + """ + return ivy.blackman_window( + self._data, periodic=periodic, dtype=dtype, device=device, out=out + ) \ No newline at end of file diff --git a/ivy/data_classes/container/experimental/creation.py b/ivy/data_classes/container/experimental/creation.py index a2b312be1cea1..29d999b38c295 100644 --- a/ivy/data_classes/container/experimental/creation.py +++ b/ivy/data_classes/container/experimental/creation.py @@ -954,3 +954,96 @@ def unsorted_segment_sum( segment_ids, num_segments, ) + + @staticmethod + def static_blackman_window( + window_length: Union[int, ivy.Container], + periodic: bool = True, + dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.blackman_window. This method simply + wraps the function, and so the docstring for ivy.blackman_window also applies to + this method with minimal changes. + Parameters + ---------- + window_length + container including multiple window sizes. + periodic + If True, returns a window to be used as periodic function. + If False, return a symmetric window. + dtype + The data type to produce. Must be a floating point type. + out + optional output container, for writing the result to. + Returns + ------- + ret + The container that contains the Blackman windows. + Examples + -------- + With one :class:`ivy.Container` input: + >>> x = ivy.Container(a=3, b=5) + >>> ivy.Container.static_blackman_window(x) + { + a: ivy.array([-1.38777878e-17, 6.30000000e-01, 6.30000000e-01]) + b: ivy.array([-1.38777878e-17, 2.00770143e-01, 8.49229857e-01, + 8.49229857e-01, 2.00770143e-01]) + } + """ + return ContainerBase.cont_multi_map_in_function( + "blackman_window", + window_length, + periodic, + dtype, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def blackman_window( + self: ivy.Container, + periodic: bool = True, + dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None, + *, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.blackman_window. This method simply + wraps the function, and so the docstring for ivy.blackman_window also applies to + this method with minimal changes. + Parameters + ---------- + self + input container with window sizes. + periodic + If True, returns a window to be used as periodic function. + If False, return a symmetric window. + dtype + The data type to produce. Must be a floating point type. + out + optional output container, for writing the result to. + Returns + ------- + ret + The container containing the Blackman windows. + Examples + -------- + With one :class:`ivy.Container` input: + >>> x = ivy.Container(a=3, b=5) + >>> ivy.blackman_window(x) + { + a: ivy.array([-1.38777878e-17, 6.30000000e-01, 6.30000000e-01]) + b: ivy.array([-1.38777878e-17, 2.00770143e-01, 8.49229857e-01, + 8.49229857e-01, 2.00770143e-01]) + } + """ + return self.static_blackman_window(self, periodic, dtype, out=out) \ No newline at end of file diff --git a/ivy/functional/backends/jax/experimental/creation.py b/ivy/functional/backends/jax/experimental/creation.py index f44f897dfd2f4..349bfda0da5cb 100644 --- a/ivy/functional/backends/jax/experimental/creation.py +++ b/ivy/functional/backends/jax/experimental/creation.py @@ -106,3 +106,23 @@ def unsorted_segment_sum( data, segment_ids, num_segments ) return jax.ops.segment_sum(data, segment_ids, num_segments) + + +def blackman_window( + size: int, + /, + *, + periodic: bool = True, + dtype: Optional[jnp.dtype] = None, + out: Optional[JaxArray] = None, +) -> JaxArray: + if size < 2: + return jnp.ones([size], dtype=dtype) + if periodic: + count = jnp.arange(size) / size + else: + count = jnp.linspace(start=0, stop=size, num=size) + return (0.42 - 0.5 * jnp.cos(2 * jnp.pi * count)) + ( + 0.08 * jnp.cos(2 * jnp.pi * 2 * count) + ) + diff --git a/ivy/functional/backends/mxnet/experimental/creation.py b/ivy/functional/backends/mxnet/experimental/creation.py index 5f833567ec121..633a0125504d0 100644 --- a/ivy/functional/backends/mxnet/experimental/creation.py +++ b/ivy/functional/backends/mxnet/experimental/creation.py @@ -51,3 +51,14 @@ def tril_indices( n_rows: int, n_cols: Optional[int] = None, k: int = 0, /, *, device: str ) -> Tuple[(Union[(None, mx.ndarray.NDArray)], ...)]: raise IvyNotImplementedException() + + +def blackman_window( + size: int, + /, + *, + periodic: bool = True, + dtype: Optional[None] = None, + out: Optional[Union[(None, mx.ndarray.NDArray)]] = None, +) -> Union[(None, mx.ndarray.NDArray)]: + raise IvyNotImplementedException() diff --git a/ivy/functional/backends/numpy/experimental/creation.py b/ivy/functional/backends/numpy/experimental/creation.py index 73fc485b2b079..874299b3137fc 100644 --- a/ivy/functional/backends/numpy/experimental/creation.py +++ b/ivy/functional/backends/numpy/experimental/creation.py @@ -113,6 +113,29 @@ def unsorted_segment_min( return res +def blackman_window( + size: int, + /, + *, + periodic: bool = True, + dtype: Optional[np.dtype] = None, + out: Optional[np.ndarray] = None, +) -> np.ndarray: + if size < 2: + return np.ones([size], dtype=dtype) + if periodic: + count = np.arange(size) / size + else: + count = np.linspace(start=0, stop=size, num=size) + + return ( + (0.42 - 0.5 * np.cos(2 * np.pi * count)) + + (0.08 * np.cos(2 * np.pi * 2 * count)) + ).astype(dtype) + + +blackman_window.support_native_out = False + def unsorted_segment_sum( data: np.ndarray, segment_ids: np.ndarray, @@ -134,3 +157,4 @@ def unsorted_segment_sum( res[i] = np.sum(data[mask_index], axis=0) return res + diff --git a/ivy/functional/backends/paddle/experimental/creation.py b/ivy/functional/backends/paddle/experimental/creation.py index c73f96cdf591c..91b064a675c66 100644 --- a/ivy/functional/backends/paddle/experimental/creation.py +++ b/ivy/functional/backends/paddle/experimental/creation.py @@ -133,6 +133,27 @@ def unsorted_segment_min( return res + + +def blackman_window( + size: int, + /, + *, + periodic: Optional[bool] = True, + dtype: Optional[paddle.dtype] = None, + out: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + if size < 2: + return paddle.ones([size], dtype=dtype) + if periodic: + count = paddle.arange(size) / size + else: + count = paddle.linspace(start=0, stop=size, num=size) + return ( + (0.42 - 0.5 * paddle.cos(2 * math.pi * count)) + + (0.08 * paddle.cos(2 * math.pi * 2 * count)) + ).cast(dtype) + def unsorted_segment_sum( data: paddle.Tensor, segment_ids: paddle.Tensor, @@ -165,3 +186,4 @@ def unsorted_segment_sum( res = paddle.cast(res, "int32") return res + diff --git a/ivy/functional/backends/tensorflow/experimental/creation.py b/ivy/functional/backends/tensorflow/experimental/creation.py index c1c5042d1635d..49c1df5931955 100644 --- a/ivy/functional/backends/tensorflow/experimental/creation.py +++ b/ivy/functional/backends/tensorflow/experimental/creation.py @@ -103,9 +103,29 @@ def unsorted_segment_min( return tf.math.unsorted_segment_min(data, segment_ids, num_segments) +def blackman_window( + size: int, + /, + *, + periodic: bool = True, + dtype: Optional[tf.DType] = None, + out: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + if size < 2: + return tnp.ones([size], dtype=tnp.result_type(size, 0.0)) + if periodic: + count = tnp.arange(size) / size + else: + count = tnp.linspace(start=0, stop=size, num=size) + + return (0.42 - 0.5 * tnp.cos(2 * tnp.pi * count)) + ( + 0.08 * tnp.cos(2 * tnp.pi * 2 * count) + ) + def unsorted_segment_sum( data: tf.Tensor, segment_ids: tf.Tensor, num_segments: Union[int, tf.Tensor], ) -> tf.Tensor: return tf.math.unsorted_segment_sum(data, segment_ids, num_segments) + diff --git a/ivy/functional/backends/torch/experimental/creation.py b/ivy/functional/backends/torch/experimental/creation.py index 73a774bae0cb0..d2894c4550e39 100644 --- a/ivy/functional/backends/torch/experimental/creation.py +++ b/ivy/functional/backends/torch/experimental/creation.py @@ -152,6 +152,24 @@ def unsorted_segment_min( return res +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) +def blackman_window( + size: int, + /, + *, + periodic: bool = True, + dtype: Optional[torch.dtype] = None, + out: Optional[torch.tensor] = None, +) -> torch.tensor: + return torch.blackman_window( + size, + periodic=periodic, + dtype=dtype, + ) + + +blackman_window.support_native_out = False + def unsorted_segment_sum( data: torch.Tensor, segment_ids: torch.Tensor, @@ -175,3 +193,5 @@ def unsorted_segment_sum( res[i] = torch.sum(data[mask_index], dim=0) return res + + diff --git a/ivy/functional/ivy/experimental/creation.py b/ivy/functional/ivy/experimental/creation.py index ea6a0b2448696..9844f76f6eff1 100644 --- a/ivy/functional/ivy/experimental/creation.py +++ b/ivy/functional/ivy/experimental/creation.py @@ -652,6 +652,7 @@ def unsorted_segment_min( return ivy.current_backend().unsorted_segment_min(data, segment_ids, num_segments) + @handle_exceptions @handle_nestable @to_native_arrays_and_back @@ -687,6 +688,53 @@ def unsorted_segment_sum( return ivy.current_backend().unsorted_segment_sum(data, segment_ids, num_segments) + +@handle_exceptions +@handle_nestable +@handle_out_argument +@to_native_arrays_and_back +@infer_dtype +@handle_device_shifting +def blackman_window( + size: int, + *, + periodic: bool = True, + dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """ + Generate a Blackman window. The Blackman window is a taper formed by using the first + three terms of a summation of cosines. It was designed to have close to the minimal + leakage possible. It is close to optimal, only slightly worse than a Kaiser window. + Parameters + ---------- + window_length + the window_length of the returned window. + periodic + If True, returns a window to be used as periodic function. + If False, return a symmetric window. + dtype + The data type to produce. Must be a floating point type. + out + optional output array, for writing the result to. + Returns + ------- + ret + The array containing the window. + Functional Examples + ------------------- + >>> ivy.blackman_window(4, periodic = True) + ivy.array([-1.38777878e-17, 3.40000000e-01, 1.00000000e+00, 3.40000000e-01]) + >>> ivy.blackman_window(7, periodic = False) + ivy.array([-1.38777878e-17, 1.30000000e-01, 6.30000000e-01, 1.00000000e+00, + 6.30000000e-01, 1.30000000e-01, -1.38777878e-17]) + """ + return ivy.current_backend().blackman_window( + size, periodic=periodic, dtype=dtype, out=out + ) + + + @handle_exceptions @handle_nestable @infer_dtype @@ -755,3 +803,4 @@ def random_tucker( return ivy.TuckerTensor.tucker_to_tensor((core, factors)) else: return ivy.TuckerTensor((core, factors)) + diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py index 71ef4a1b0d16a..fb14fa52401ad 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py @@ -395,6 +395,34 @@ def test_unsorted_segment_sum( ) +@handle_test( + fn_tree="functional.ivy.experimental.unsorted_segment_sum", + d_x_n_s=valid_unsorted_segment_min_inputs(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_unsorted_segment_sum( + *, + d_x_n_s, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, data, num_segments, segment_ids = d_x_n_s + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + fw=backend_fw, + fn_name=fn_name, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + ) + + + @st.composite def _random_tucker_data(draw): shape = draw( @@ -469,3 +497,4 @@ def test_random_tucker( for f, f_gt in zip(factors, factors_gt): assert np.prod(f.shape) == np.prod(f_gt.shape) +