Skip to content

Commit

Permalink
feat: added support for complex dtype in hardswish (ivy-llc#23609)
Browse files Browse the repository at this point in the history
  • Loading branch information
mosesdaudu001 authored Sep 15, 2023
1 parent f482c18 commit 4b53291
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 13 deletions.
13 changes: 11 additions & 2 deletions ivy/data_classes/array/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,23 @@ def mish(
"""
return ivy.mish(self._data, complex_mode=complex_mode, out=out)

def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def hardswish(
self: ivy.Array,
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the hardswish activation function element-wise.
Parameters
----------
x
input array
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to. It must have
a shape that the inputs broadcast to.
Expand Down Expand Up @@ -385,4 +394,4 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr
b: ivy.array([0., 5.])
}
"""
return ivy.hardswish(self._data, out=out)
return ivy.hardswish(self._data, complex_mode=complex_mode, out=out)
10 changes: 10 additions & 0 deletions ivy/data_classes/container/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,7 @@ def _static_hardswish(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -1142,6 +1143,9 @@ def _static_hardswish(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -1169,6 +1173,7 @@ def _static_hardswish(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

Expand All @@ -1180,6 +1185,7 @@ def hardswish(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -1202,6 +1208,9 @@ def hardswish(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand All @@ -1228,5 +1237,6 @@ def hardswish(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)
8 changes: 7 additions & 1 deletion ivy/functional/backends/jax/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,11 @@ def mish(
return x * jnp.tanh(jax.nn.softplus(x))


def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def hardswish(
    x: JaxArray,
    /,
    *,
    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
    out: Optional[JaxArray] = None,
) -> JaxArray:
    """
    Apply the hardswish activation element-wise using ``jax.nn.hard_swish``.

    Parameters
    ----------
    x
        input array.
    complex_mode
        accepted so the signature matches the other backends; it is not
        read inside this function.
    out
        accepted so the signature matches the other backends; it is not
        read inside this function.

    Returns
    -------
    ret
        the value produced by ``jax.nn.hard_swish(x)``.
    """
    activated = jax.nn.hard_swish(x)
    return activated
8 changes: 7 additions & 1 deletion ivy/functional/backends/numpy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,13 @@ def mish(


@_scalar_output_to_0d_array
def hardswish(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def hardswish(
    x: np.ndarray,
    /,
    *,
    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Apply the hardswish activation element-wise: ``x * min(max(x + 3, 0), 6) / 6``.

    Parameters
    ----------
    x
        input array.
    complex_mode
        accepted so the signature matches the other backends; it is not
        read inside this function.
    out
        optional output array. When given, the final hardswish values are
        copied into it and it is returned.

    Returns
    -------
    ret
        an array with the dtype of ``x`` holding the hardswish values
        (``out`` itself when ``out`` is provided).
    """
    # Clamp (x + 3) to the range [0, 6]; this is only an intermediate value
    # and must never be written into the caller's ``out`` buffer.
    gate = np.minimum(np.maximum(x + 3, 0, dtype=x.dtype), 6, dtype=x.dtype)
    result = (x * gate / 6).astype(x.dtype)
    if out is None:
        return result
    # Write the *final* result into ``out`` (previously ``out`` received the
    # clamped intermediate instead of the activation output).
    np.copyto(out, result)
    return out

Expand Down
6 changes: 5 additions & 1 deletion ivy/functional/backends/paddle/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ def mish(
{"2.5.1 and below": {"cpu": ("float16",)}}, backend_version
)
def hardswish(
x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
x: paddle.Tensor,
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
return F.hardswish(x)
8 changes: 7 additions & 1 deletion ivy/functional/backends/tensorflow/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,11 @@ def mish(


@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
def hardswish(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
def hardswish(
    x: Tensor,
    /,
    *,
    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
    out: Optional[Tensor] = None,
) -> Tensor:
    """
    Apply the hardswish activation element-wise: ``x * relu6(x + 3) / 6``.

    Parameters
    ----------
    x
        input tensor.
    complex_mode
        accepted so the signature matches the other backends; it is not
        read inside this function.
    out
        accepted so the signature matches the other backends; it is not
        read inside this function.

    Returns
    -------
    ret
        the tensor ``x * tf.nn.relu6(x + 3) / 6``.
    """
    shifted = x + 3
    gate = tf.nn.relu6(shifted)
    return (x * gate) / 6
6 changes: 5 additions & 1 deletion ivy/functional/backends/torch/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def mish(
backend_version,
)
def hardswish(
x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None
x: torch.Tensor,
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.nn.functional.hardswish(x)
26 changes: 25 additions & 1 deletion ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,15 +731,33 @@ def mish(
return current_backend(x).mish(x, out=out)


def _hardswish_jax_like(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    fn_original=None,
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    """
    Compute hardswish as ``x * (relu6(x + 3) / 6)`` using ivy functions.

    Parameters
    ----------
    x
        input array.
    fn_original
        not read inside this function.
    out
        not read inside this function.

    Returns
    -------
    ret
        ``x`` multiplied by the gate ``relu6(x + 3) / 6``, where the gate is
        cast to ``x.dtype`` before the multiplication.
    """
    # Hard-sigmoid gate, cast back to the input dtype before multiplying.
    gate = (ivy.relu6(x + 3.0) / 6).astype(x.dtype)
    return ivy.multiply(x, gate)


@handle_exceptions
@handle_backend_invalid
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_complex_input
def hardswish(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the hardswish activation function element-wise.
Expand All @@ -748,6 +766,9 @@ def hardswish(
----------
x
input array
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Expand Down Expand Up @@ -777,3 +798,6 @@ def hardswish(
}
"""
return current_backend(x).hardswish(x, out=out)


# Expose ``_hardswish_jax_like`` as the ``jax_like`` attribute of ``hardswish``.
# NOTE(review): presumably consumed by ``handle_complex_input`` when
# ``complex_mode="jax"`` — confirm in ``ivy.func_wrapper``.
hardswish.jax_like = _hardswish_jax_like
15 changes: 12 additions & 3 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,17 @@ def _forward(self, x):


class Hardswish(Module):
def __init__(self):
"""Apply the HARDSWISH activation function."""
def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
"""
Apply the HARDSWISH activation function.
Parameters
----------
complex_mode
Specifies how to handle complex input. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
"""
self._complex_mode = complex_mode
Module.__init__(self)

def _forward(self, x):
Expand All @@ -402,7 +411,7 @@ def _forward(self, x):
ret
The outputs following the HARDSWISH activation *[batch_shape, d]*
"""
return ivy.hardswish(x)
return ivy.hardswish(x, complex_mode=self._complex_mode)


class Logit(Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_jax_hard_silu(
@handle_frontend_test(
fn_tree="jax.nn.hard_swish",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("float_and_complex"),
min_value=-10,
max_value=10,
safety_factor_scale="linear",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,17 @@ def test_gelu(
@handle_test(
fn_tree="functional.ivy.hardswish",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("float_and_complex"),
large_abs_safety_factor=8,
small_abs_safety_factor=8,
safety_factor_scale="log",
),
complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
)
def test_hardswish(
*,
dtype_and_x,
complex_mode,
test_flags,
backend_fw,
fn_name,
Expand All @@ -70,6 +72,7 @@ def test_hardswish(
fn_name=fn_name,
on_device=on_device,
x=x[0],
complex_mode=complex_mode,
)


Expand Down

0 comments on commit 4b53291

Please sign in to comment.