diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py
index ae7d7d8b97c42..881c0f4b537c4 100644
--- a/ivy/data_classes/array/experimental/activations.py
+++ b/ivy/data_classes/array/experimental/activations.py
@@ -490,3 +490,43 @@ def scaled_tanh(
         ivy.array([0.1, 0.1, 0.1])
         """
         return ivy.scaled_tanh(self._data, alpha=alpha, beta=beta, out=out)
+
+    def hardshrink(
+        self: ivy.Array,
+        /,
+        *,
+        lambd: float = 0.5,
+        out: Optional[ivy.Array] = None,
+    ) -> ivy.Array:
+        """
+        ivy.Array instance method variant of ivy.hardshrink. This method simply wraps
+        the function, and so the docstring for ivy.hardshrink also applies to this
+        method with minimal changes.
+
+        Parameters
+        ----------
+        self
+            input array.
+        lambd
+            the lambd value for the Hardshrink formulation
+        out
+            optional output array, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            an array with the hardshrink activation function applied element-wise.
+
+        Examples
+        --------
+        >>> x = ivy.array([-1., 0., 1.])
+        >>> y = x.hardshrink()
+        >>> print(y)
+        ivy.array([-1., 0., 1.])
+        >>> x = ivy.array([-1., 0., 1.])
+        >>> y = x.hardshrink(lambd=1.0)
+        >>> print(y)
+        ivy.array([0., 0., 0.])
+        """
+        return ivy.hardshrink(self._data, lambd=lambd, out=out)
diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py
index 59a7aab8dd906..c6a7cb689379e 100644
--- a/ivy/data_classes/container/experimental/activations.py
+++ b/ivy/data_classes/container/experimental/activations.py
@@ -1613,3 +1613,121 @@ def scaled_tanh(
             map_sequences=map_sequences,
             out=out,
         )
+
+    @staticmethod
+    def _static_hardshrink(
+        x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
+        /,
+        *,
+        lambd: ivy.Container = 0.5,
+        key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+        to_apply: Union[bool, ivy.Container] = False,
+        prune_unapplied: Union[bool, ivy.Container] = True,
+        map_sequences: Union[bool, ivy.Container] = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container static method variant of ivy.hardshrink. This method simply wraps
+        the function, and so the docstring for ivy.hardshrink also applies to this
+        method with minimal changes.
+
+        Parameters
+        ----------
+        x
+            input container.
+        lambd
+            Lambda value for hard shrinkage calculation.
+        key_chains
+            The key-chains to apply or not apply the method to.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples).
+
+        Returns
+        -------
+        ret
+            Container with hard shrinkage applied to the leaves.
+
+        Examples
+        --------
+        >>> x = ivy.Container(a=ivy.array([1., -2.]), b=ivy.array([0.4, -0.2]))
+        >>> y = ivy.Container._static_hardshrink(x)
+        >>> print(y)
+        {
+            a: ivy.array([1., -2.]),
+            b: ivy.array([0., 0.])
+        }
+        """
+        return ContainerBase.cont_multi_map_in_function(
+            "hardshrink",
+            x,
+            lambd=lambd,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
+
+    def hardshrink(
+        self: ivy.Container,
+        /,
+        *,
+        lambd: ivy.Container = 0.5,
+        key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+        to_apply: Union[bool, ivy.Container] = False,
+        prune_unapplied: Union[bool, ivy.Container] = True,
+        map_sequences: Union[bool, ivy.Container] = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        Apply the hard shrinkage function element-wise.
+
+        Parameters
+        ----------
+        self
+            Input container.
+        lambd
+            Lambda value for hard shrinkage calculation.
+        key_chains
+            The key-chains to apply or not apply the method to.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples).
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            Container with hard shrinkage applied to the leaves.
+
+        Examples
+        --------
+        >>> import ivy.numpy as np
+        >>> x = ivy.Container(a=np.array([1., -2.]), b=np.array([0.4, -0.2]))
+        >>> y = ivy.Container.hardshrink(x)
+        >>> print(y)
+        {
+            a: ivy.array([1., -2.]),
+            b: ivy.array([0., 0.])
+        }
+        """
+        return self._static_hardshrink(
+            self,
+            lambd=lambd,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py
index c8234ea93ea7b..c21f9c034e369 100644
--- a/ivy/functional/backends/jax/experimental/activations.py
+++ b/ivy/functional/backends/jax/experimental/activations.py
@@ -135,3 +135,13 @@ def scaled_tanh(
     out: Optional[JaxArray] = None,
 ) -> JaxArray:
     return alpha * jax.nn.tanh(beta * x)
+
+
+@with_unsupported_dtypes({"0.4.16 and below": ("float16", "bfloat16")}, backend_version)
+def hardshrink(
+    x: JaxArray, /, *, lambd: float = 0.5, out: Optional[JaxArray] = None
+) -> JaxArray:
+    ret = jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0))
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret).astype(x.dtype)
+    return ret
diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py
index 5070d2f09e327..b84c15cb7a062 100644
--- a/ivy/functional/backends/numpy/experimental/activations.py
+++ b/ivy/functional/backends/numpy/experimental/activations.py
@@ -171,3 +171,16 @@ def scaled_tanh(
     out: Optional[np.ndarray] = None,
 ) -> np.ndarray:
     return alpha * np.tanh(beta * x)
+
+
+@_scalar_output_to_0d_array
+def hardshrink(
+    x: np.ndarray, /, *, lambd: float = 0.5, out: Optional[np.ndarray] = None
+) -> np.ndarray:
+    ret = np.where(x > lambd, x, np.where(x < -lambd, x, 0))
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret).astype(x.dtype)
+    return ivy.astype(ret, x.dtype)
+
+
+hardshrink.support_native_out = True
diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py
index c45e8f9c2a995..d414972706b16 100644
--- a/ivy/functional/backends/paddle/experimental/activations.py
+++ b/ivy/functional/backends/paddle/experimental/activations.py
@@ -214,3 +214,20 @@ def scaled_tanh(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     return paddle.stanh(x, scale_a=beta, scale_b=alpha)
+
+
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("float16", "bfloat16")}},
+    backend_version,
+)
+def hardshrink(
+    x: paddle.Tensor, /, *, lambd: float = 0.5, out: Optional[paddle.Tensor] = None
+) -> paddle.Tensor:
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return F.hardshrink(x, threshold=lambd)
+    if paddle.is_complex(x):
+        return paddle.complex(
+            F.hardshrink(x.real(), threshold=lambd),
+            F.hardshrink(x.imag(), threshold=lambd),
+        )
+    return F.hardshrink(x.cast("float32"), threshold=lambd).cast(x.dtype)
diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py
index 401a363eb5cf2..8427d803331f5 100644
--- a/ivy/functional/backends/tensorflow/experimental/activations.py
+++ b/ivy/functional/backends/tensorflow/experimental/activations.py
@@ -153,3 +153,21 @@ def scaled_tanh(
     out: Optional[Tensor] = None,
 ) -> Tensor:
     return alpha * tf.nn.tanh(beta * x)
+
+
+@with_supported_dtypes({"2.14.0 and below": ("float",)}, backend_version)
+def hardshrink(
+    x: Tensor,
+    /,
+    *,
+    lambd: float = 0.5,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    ret = tf.where(
+        tf.math.greater(x, lambd),
+        x,
+        tf.where(tf.math.less(x, -lambd), x, 0),
+    )
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret).astype(x.dtype)
+    return ivy.astype(ret, x.dtype)
diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py
index 12666ddd3ae74..34b51951f75c1 100644
--- a/ivy/functional/backends/torch/experimental/activations.py
+++ b/ivy/functional/backends/torch/experimental/activations.py
@@ -137,3 +137,13 @@ def scaled_tanh(
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return alpha * torch.nn.functional.tanh(beta * x)
+
+
+@with_unsupported_dtypes({"2.1.0 and below": ("float16", "bfloat16")}, backend_version)
+def hardshrink(
+    x: torch.Tensor, /, *, lambd: float = 0.5, out: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    ret = torch.nn.functional.hardshrink(x, lambd=lambd)
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret).astype(x.dtype)
+    return ivy.astype(ret, x.dtype)
diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py
index 950036652a10f..c78f44198b9d6 100644
--- a/ivy/functional/ivy/experimental/activations.py
+++ b/ivy/functional/ivy/experimental/activations.py
@@ -866,3 +866,55 @@ def scaled_tanh(
 
 
 stanh = scaled_tanh
+
+
+@handle_exceptions
+@handle_backend_invalid
+@handle_nestable
+@handle_array_like_without_promotion
+@handle_out_argument
+@to_native_arrays_and_back
+@handle_array_function
+def hardshrink(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    lambd: float = 0.5,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    """
+    Apply the hardshrink function element-wise.
+
+    Parameters
+    ----------
+    x
+        input array.
+    lambd
+        the value for the Hardshrink formulation.
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
+
+    Returns
+    -------
+    ret
+        an array containing the hardshrink activation of each element in ``x``.
+
+    Examples
+    --------
+    With :class:`ivy.Array` input:
+    >>> x = ivy.array([-1.0, 1.0, 2.0])
+    >>> y = ivy.hardshrink(x)
+    >>> print(y)
+    ivy.array([-1., 1., 2.])
+    >>> x = ivy.array([-1.0, 1.0, 2.0])
+    >>> y = x.hardshrink()
+    >>> print(y)
+    ivy.array([-1., 1., 2.])
+    >>> x = ivy.array([[-1.3, 3.8, 2.1], [1.7, 4.2, -6.6]])
+    >>> y = ivy.hardshrink(x)
+    >>> print(y)
+    ivy.array([[-1.29999995, 3.79999995, 2.0999999 ],
+           [ 1.70000005, 4.19999981, -6.5999999 ]])
+    """
+    return current_backend(x).hardshrink(x, lambd=lambd, out=out)
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
index bcd63fe245674..ab60d709d92ed 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
@@ -70,6 +70,34 @@ def test_elu(
     )
 
 
+# hardshrink
+@handle_test(
+    fn_tree="functional.ivy.experimental.hardshrink",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        large_abs_safety_factor=8,
+        small_abs_safety_factor=8,
+        safety_factor_scale="log",
+    ),
+    threshold=st.one_of(
+        st.floats(min_value=0.0, max_value=1e30),
+    ),
+)
+def test_hardshrink(
+    *, dtype_and_x, threshold, test_flags, backend_fw, fn_name, on_device
+):
+    dtype, x = dtype_and_x
+    helpers.test_function(
+        input_dtypes=dtype,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_name=fn_name,
+        on_device=on_device,
+        x=x[0],
+        lambd=threshold,
+    )
+
+
 # hardtanh
 @handle_test(
     fn_tree="functional.ivy.experimental.hardtanh",
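Reviewer note (not part of the patch): a minimal sanity-check sketch of the rule this PR implements, assuming the patch above is applied so that ivy.hardshrink, ivy.Array.hardshrink and ivy.Container.hardshrink exist. Hardshrink keeps x where |x| > lambd and zeroes it otherwise, so all three entry points should agree with a plain ivy.where reference.

# sketch: compare the new hardshrink entry points against an explicit reference
import ivy

ivy.set_backend("numpy")  # any backend with the new hardshrink kernel should behave the same

x = ivy.array([-1.3, -0.4, 0.0, 0.4, 2.1])
# reference implementation of the hardshrink rule with the default lambd=0.5
expected = ivy.where(ivy.abs(x) > 0.5, x, ivy.zeros_like(x))

print(ivy.hardshrink(x))   # functional API added in this patch
print(x.hardshrink())      # ivy.Array instance method added in this patch
print(expected)            # all three printed arrays should match

c = ivy.Container(a=x, b=ivy.array([0.2, -0.2]))
print(c.hardshrink(lambd=1.0))  # container method maps the rule over each leaf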