feat: (added hardshrink to ivy experimental) #27012

Merged
merged 8 commits on Oct 24, 2023
40 changes: 40 additions & 0 deletions ivy/data_classes/array/experimental/activations.py
@@ -490,3 +490,43 @@ def scaled_tanh(
ivy.array([0.1, 0.1, 0.1])
"""
return ivy.scaled_tanh(self._data, alpha=alpha, beta=beta, out=out)

def hardshrink(
self: ivy.Array,
/,
*,
lambd: float = 0.5,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.hardshrink. This method simply wraps
the function, and so the docstring for ivy.hardshrink also applies to this
method with minimal changes.

Parameters
----------
self
input array.
lambd
the lambd value for the Hardshrink formulation; elements whose absolute value is at or below lambd are set to zero.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.

Returns
-------
ret
an array with the hardshrink activation function applied element-wise.

Examples
--------
>>> x = ivy.array([-1., 0., 1.])
>>> y = x.hardshrink()
>>> print(y)
ivy.array([-1., 0., 1.])
>>> x = ivy.array([-1., 0., 1.])
>>> y = x.hardshrink(lambd=1.0)
>>> print(y)
ivy.array([0., 0., 0.])
"""
return ivy.hardshrink(self._data, lambd=lambd, out=out)
118 changes: 118 additions & 0 deletions ivy/data_classes/container/experimental/activations.py
@@ -1613,3 +1613,121 @@ def scaled_tanh(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def _static_hardshrink(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
lambd: Union[float, ivy.Container] = 0.5,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.hardshrink. This method simply wraps
the function, and so the docstring for ivy.hardshrink also applies to this
method with minimal changes.

Parameters
----------
x
input container.
lambd
Lambda value for hard shrinkage calculation.
key_chains
The key-chains to apply or not apply the method to.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
map_sequences
Whether to also map method to sequences (lists, tuples).

Returns
-------
ret
Container with hard shrinkage applied to the leaves.

Examples
--------
>>> x = ivy.Container(a=ivy.array([1., -2.]), b=ivy.array([0.4, -0.2]))
>>> y = ivy.Container._static_hardshrink(x)
>>> print(y)
{
a: ivy.array([1., -2.]),
b: ivy.array([0., 0.])
}
"""
return ContainerBase.cont_multi_map_in_function(
"hardshrink",
x,
lambd=lambd,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def hardshrink(
self: ivy.Container,
/,
*,
lambd: Union[float, ivy.Container] = 0.5,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Apply the hard shrinkage function element-wise.

Parameters
----------
self
Input container.
lambd
Lambda value for hard shrinkage calculation.
key_chains
The key-chains to apply or not apply the method to.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
map_sequences
Whether to also map method to sequences (lists, tuples).
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

Returns
-------
ret
Container with hard shrinkage applied to the leaves.

Examples
--------
>>> x = ivy.Container(a=ivy.array([1., -2.]), b=ivy.array([0.4, -0.2]))
>>> y = x.hardshrink()
>>> print(y)
{
a: ivy.array([1., -2.]),
b: ivy.array([0., 0.])
}
"""
return self._static_hardshrink(
self,
lambd=lambd,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)
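For illustration, a small usage sketch of the instance method above with a non-default lambd; the values are chosen to be exactly representable in float32, though the exact printed formatting may vary slightly across backends:

>>> x = ivy.Container(a=ivy.array([-1.5, 0.25]), b=ivy.array([2., 0.5]))
>>> print(x.hardshrink(lambd=1.0))
{
    a: ivy.array([-1.5, 0.]),
    b: ivy.array([2., 0.])
}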
10 changes: 10 additions & 0 deletions ivy/functional/backends/jax/experimental/activations.py
@@ -135,3 +135,13 @@ def scaled_tanh(
out: Optional[JaxArray] = None,
) -> JaxArray:
return alpha * jax.nn.tanh(beta * x)


@with_unsupported_dtypes({"0.4.16 and below": ("float16", "bfloat16")}, backend_version)
def hardshrink(
x: JaxArray, /, *, lambd: float = 0.5, out: Optional[JaxArray] = None
) -> JaxArray:
ret = jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return ret
13 changes: 13 additions & 0 deletions ivy/functional/backends/numpy/experimental/activations.py
@@ -171,3 +171,16 @@ def scaled_tanh(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return alpha * np.tanh(beta * x)


@_scalar_output_to_0d_array
def hardshrink(
x: np.ndarray, /, *, lambd: float = 0.5, out: Optional[np.ndarray] = None
) -> np.ndarray:
ret = np.where(x > lambd, x, np.where(x < -lambd, x, 0))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return ivy.astype(ret, x.dtype)


hardshrink.support_native_out = True
17 changes: 17 additions & 0 deletions ivy/functional/backends/paddle/experimental/activations.py
@@ -214,3 +214,20 @@ def scaled_tanh(
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
return paddle.stanh(x, scale_a=beta, scale_b=alpha)


@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("float16", "bfloat16")}},
backend_version,
)
def hardshrink(
x: paddle.Tensor, /, *, lambd: float = 0.5, out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.hardshrink(x, threshold=lambd)
if paddle.is_complex(x):
return paddle.complex(
F.hardshrink(x.real(), threshold=lambd),
F.hardshrink(x.imag(), threshold=lambd),
)
return F.hardshrink(x.cast("float32"), threshold=lambd).cast(x.dtype)
18 changes: 18 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/activations.py
@@ -153,3 +153,21 @@ def scaled_tanh(
out: Optional[Tensor] = None,
) -> Tensor:
return alpha * tf.nn.tanh(beta * x)


@with_supported_dtypes({"2.14.0 and below": ("float",)}, backend_version)
def hardshrink(
x: Tensor,
/,
*,
lambd: float = 0.5,
out: Optional[Tensor] = None,
) -> Tensor:
ret = tf.where(
tf.math.greater(x, lambd),
x,
tf.where(tf.math.less(x, -lambd), x, 0),
)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return ivy.astype(ret, x.dtype)
10 changes: 10 additions & 0 deletions ivy/functional/backends/torch/experimental/activations.py
@@ -137,3 +137,13 @@ def scaled_tanh(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return alpha * torch.nn.functional.tanh(beta * x)


@with_unsupported_dtypes({"2.1.0 and below": ("float16", "bfloat16")}, backend_version)
def hardshrink(
x: torch.Tensor, /, *, lambd: float = 0.5, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
ret = torch.nn.functional.hardshrink(x, lambd=lambd)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return ivy.astype(ret, x.dtype)
52 changes: 52 additions & 0 deletions ivy/functional/ivy/experimental/activations.py
@@ -866,3 +866,55 @@ def scaled_tanh(


stanh = scaled_tanh


@handle_exceptions
@handle_backend_invalid
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
def hardshrink(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
lambd: float = 0.5,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the hardshrink function element-wise.

Parameters
----------
x
input array.
lambd
the lambd value for the Hardshrink formulation; elements whose absolute value is at or below lambd are set to zero.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.

Returns
-------
ret
an array containing the hardshrink activation of each element in ``x``.

Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.array([-1.0, 1.0, 2.0])
>>> y = ivy.hardshrink(x)
>>> print(y)
ivy.array([-1., 1., 2.])
>>> x = ivy.array([-1.0, 1.0, 2.0])
>>> y = x.hardshrink(lambd=1.0)
>>> print(y)
ivy.array([0., 0., 2.])
>>> x = ivy.array([[-1.3, 3.8, 2.1], [1.7, 4.2, -6.6]])
>>> y = ivy.hardshrink(x)
>>> print(y)
ivy.array([[-1.29999995, 3.79999995, 2.0999999 ],
[ 1.70000005, 4.19999981, -6.5999999 ]])
"""
return current_backend(x).hardshrink(x, lambd=lambd, out=out)
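All of the backend implementations above reduce to the same elementwise rule: keep x where |x| > lambd, and return zero otherwise. A minimal NumPy sketch of that rule, for reference only; the helper name hardshrink_ref is hypothetical and not part of this PR:

import numpy as np

def hardshrink_ref(x, lambd=0.5):
    # zero out every element whose magnitude is at most lambd,
    # keep elements strictly outside [-lambd, lambd] unchanged
    x = np.asarray(x, dtype=float)
    return np.where(np.abs(x) > lambd, x, 0.0)

# e.g. hardshrink_ref([-1.0, 0.3, 1.0]) -> array([-1., 0., 1.])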
@@ -70,6 +70,34 @@ def test_elu(
)


# hardshrink
@handle_test(
fn_tree="functional.ivy.experimental.hardshrink",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
large_abs_safety_factor=8,
small_abs_safety_factor=8,
safety_factor_scale="log",
),
threshold=st.floats(min_value=0.0, max_value=1e30),
)
def test_hardshrink(
*, dtype_and_x, threshold, test_flags, backend_fw, fn_name, on_device
):
dtype, x = dtype_and_x
helpers.test_function(
input_dtypes=dtype,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_name=fn_name,
on_device=on_device,
x=x[0],
lambd=threshold,
)


# hardtanh
@handle_test(
fn_tree="functional.ivy.experimental.hardtanh",