diff --git a/ivy/functional/backends/jax/manipulation.py b/ivy/functional/backends/jax/manipulation.py
index 61e638b9509cb..804774fc0c5c9 100644
--- a/ivy/functional/backends/jax/manipulation.py
+++ b/ivy/functional/backends/jax/manipulation.py
@@ -203,45 +203,28 @@ def tile(
 
 def clip(
     x: JaxArray,
-    x_min: Union[Number, JaxArray],
-    x_max: Union[Number, JaxArray],
+    x_min: Optional[Union[Number, JaxArray]] = None,
+    x_max: Optional[Union[Number, JaxArray]] = None,
     /,
     *,
     out: Optional[JaxArray] = None,
 ) -> JaxArray:
-    if (
-        hasattr(x_min, "dtype")
-        and hasattr(x_max, "dtype")
-        and (x.dtype != x_min.dtype or x.dtype != x_max.dtype)
-    ):
-        if (jnp.float16 in (x.dtype, x_min.dtype, x_max.dtype)) and (
-            jnp.int16 in (x.dtype, x_min.dtype, x_max.dtype)
-            or jnp.uint16 in (x.dtype, x_min.dtype, x_max.dtype)
-        ):
-            promoted_type = jnp.promote_types(x.dtype, jnp.float32)
-            promoted_type = jnp.promote_types(promoted_type, x_min.dtype)
-            promoted_type = jnp.promote_types(promoted_type, x_max.dtype)
-            x = x.astype(promoted_type)
-        elif (
-            jnp.float16 in (x.dtype, x_min.dtype, x_max.dtype)
-            or jnp.float32 in (x.dtype, x_min.dtype, x_max.dtype)
-        ) and (
-            jnp.int32 in (x.dtype, x_min.dtype, x_max.dtype)
-            or jnp.uint32 in (x.dtype, x_min.dtype, x_max.dtype)
-            or jnp.uint64 in (x.dtype, x_min.dtype, x_max.dtype)
-            or jnp.int64 in (x.dtype, x_min.dtype, x_max.dtype)
-        ):
-            promoted_type = jnp.promote_types(x.dtype, jnp.float64)
-            promoted_type = jnp.promote_types(promoted_type, x_min.dtype)
-            promoted_type = jnp.promote_types(promoted_type, x_max.dtype)
-            x = x.astype(promoted_type)
-        else:
-            promoted_type = jnp.promote_types(x.dtype, x_min.dtype)
-            promoted_type = jnp.promote_types(promoted_type, x_max.dtype)
-            x.astype(promoted_type)
-    # jnp.clip isn't used because of inconsistent gradients
-    x = jnp.where(x > x_max, x_max, x)
-    return jnp.where(x < x_min, x_min, x)
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of x_min or x_max must be provided")
+    promoted_type = x.dtype
+    if x_min is not None:
+        if not hasattr(x_min, "dtype"):
+            x_min = ivy.array(x_min).data
+        promoted_type = ivy.as_native_dtype(ivy.promote_types(x.dtype, x_min.dtype))
+        x = jnp.where(x < x_min, x_min.astype(promoted_type), x.astype(promoted_type))
+    if x_max is not None:
+        if not hasattr(x_max, "dtype"):
+            x_max = ivy.array(x_max).data
+        promoted_type = ivy.as_native_dtype(
+            ivy.promote_types(promoted_type, x_max.dtype)
+        )
+        x = jnp.where(x > x_max, x_max.astype(promoted_type), x.astype(promoted_type))
+    return x
 
 
 @with_unsupported_dtypes({"0.4.14 and below": ("uint64",)}, backend_version)
diff --git a/ivy/functional/backends/numpy/manipulation.py b/ivy/functional/backends/numpy/manipulation.py
index f258de2b5c59d..dd5a0913a2549 100644
--- a/ivy/functional/backends/numpy/manipulation.py
+++ b/ivy/functional/backends/numpy/manipulation.py
@@ -267,7 +267,18 @@ def clip(
     *,
     out: Optional[np.ndarray] = None,
 ) -> np.ndarray:
-    return np.asarray(np.clip(x, x_min, x_max, out=out), dtype=x.dtype)
+    promoted_type = x.dtype
+    if x_min is not None:
+        if not hasattr(x_min, "dtype"):
+            x_min = ivy.array(x_min).data
+        promoted_type = ivy.as_native_dtype(ivy.promote_types(x.dtype, x_min.dtype))
+    if x_max is not None:
+        if not hasattr(x_max, "dtype"):
+            x_max = ivy.array(x_max).data
+        promoted_type = ivy.as_native_dtype(
+            ivy.promote_types(promoted_type, x_max.dtype)
+        )
+    return np.clip(x.astype(promoted_type), x_min, x_max, out=out)
 
 
 clip.support_native_out = True
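Note: every backend in this patch follows the same recipe — promote the input's dtype against whichever bounds were actually supplied, then apply the bounds one at a time. A minimal standalone sketch of that recipe in plain NumPy (the helper name `clip_sketch` is illustrative and not part of this diff):

```python
import numpy as np

def clip_sketch(x, x_min=None, x_max=None):
    # Mirror of the backend logic above: promote against each provided
    # bound, then clip the promoted input.
    if x_min is None and x_max is None:
        raise ValueError("At least one of x_min or x_max must be provided")
    promoted = x.dtype
    if x_min is not None:
        x_min = np.asarray(x_min)
        promoted = np.promote_types(promoted, x_min.dtype)
    if x_max is not None:
        x_max = np.asarray(x_max)
        promoted = np.promote_types(promoted, x_max.dtype)
    # np.clip accepts None for either bound, but not for both at once.
    return np.clip(x.astype(promoted), x_min, x_max)

print(clip_sketch(np.arange(5, dtype=np.int32), x_max=np.float64(2.5)))
# [0.  1.  2.  2.5 2.5] -- the int32 input was promoted to float64
```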
diff --git a/ivy/functional/backends/paddle/manipulation.py b/ivy/functional/backends/paddle/manipulation.py
index 4da5a1dbd9b00..bf654b3f253a4 100644
--- a/ivy/functional/backends/paddle/manipulation.py
+++ b/ivy/functional/backends/paddle/manipulation.py
@@ -444,7 +444,26 @@ def clip(
     *,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    return paddle_backend.minimum(paddle_backend.maximum(x, x_min), x_max)
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of x_min or x_max must be provided")
+    promoted_type = x.dtype
+    if x_min is not None:
+        if not hasattr(x_min, "dtype"):
+            x_min = ivy.array(x_min).data
+        promoted_type = ivy.as_native_dtype(ivy.promote_types(x.dtype, x_min.dtype))
+        x = paddle_backend.maximum(
+            paddle.cast(x, promoted_type), paddle.cast(x_min, promoted_type)
+        )
+    if x_max is not None:
+        if not hasattr(x_max, "dtype"):
+            x_max = ivy.array(x_max).data
+        promoted_type = ivy.as_native_dtype(
+            ivy.promote_types(promoted_type, x_max.dtype)
+        )
+        x = paddle_backend.minimum(
+            paddle.cast(x, promoted_type), paddle.cast(x_max, promoted_type)
+        )
+    return x
 
 
 def unstack(
diff --git a/ivy/functional/backends/tensorflow/manipulation.py b/ivy/functional/backends/tensorflow/manipulation.py
index 41800bca81d4f..7142dc34f1808 100644
--- a/ivy/functional/backends/tensorflow/manipulation.py
+++ b/ivy/functional/backends/tensorflow/manipulation.py
@@ -337,22 +337,33 @@ def clip(
     *,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
-    if hasattr(x_min, "dtype") and hasattr(x_max, "dtype"):
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of x_min or x_max must be provided")
+    promoted_type = x.dtype
+    if x_min is not None:
+        if not hasattr(x_min, "dtype"):
+            x_min = ivy.array(x_min).data
         promoted_type = ivy.as_native_dtype(ivy.promote_types(x.dtype, x_min.dtype))
+    if x_max is not None:
+        if not hasattr(x_max, "dtype"):
+            x_max = ivy.array(x_max).data
         promoted_type = ivy.as_native_dtype(
             ivy.promote_types(promoted_type, x_max.dtype)
         )
-        x = tf.cast(x, promoted_type)
-        x_min = tf.cast(x_min, promoted_type)
         x_max = tf.cast(x_max, promoted_type)
-    if tf.size(x) == 0:
-        ret = x
-    elif x.dtype == tf.bool:
-        ret = tf.clip_by_value(tf.cast(x, tf.float16), x_min, x_max)
-        ret = tf.cast(ret, x.dtype)
+    x = tf.cast(x, promoted_type)
+    if x_min is not None:
+        x_min = tf.cast(x_min, promoted_type)
+    cond = True
+    if x_min is not None and x_max is not None:
+        if tf.math.reduce_any(tf.experimental.numpy.greater(x_min, x_max)):
+            cond = False
+    if cond:
+        return tf.experimental.numpy.clip(x, x_min, x_max)
     else:
-        ret = tf.clip_by_value(x, x_min, x_max)
-    return ret
+        return tf.experimental.numpy.minimum(
+            x_max, tf.experimental.numpy.maximum(x, x_min)
+        )
 
 
 def unstack(
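The `cond` check in the TensorFlow branch above guards the crossed-bounds case: when some element of `x_min` exceeds `x_max`, the explicit minimum/maximum fallback guarantees NumPy's documented behavior (the lower bound is applied first, the upper bound second, so `x_max` wins) regardless of how the underlying clip op resolves `x_min > x_max`. A small sketch of the semantics being preserved:

```python
import numpy as np

x = np.array([0.0, 5.0, 10.0])
x_min, x_max = 6.0, 4.0  # deliberately crossed bounds

# Lower bound first, upper bound second: x_max wins for every element.
manual = np.minimum(x_max, np.maximum(x, x_min))
print(manual)                    # [4. 4. 4.]
print(np.clip(x, x_min, x_max))  # [4. 4. 4.] -- same ordering as the fallback
```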
diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py
index fe8b37b4dcafd..b0ea41b614fcf 100644
--- a/ivy/functional/backends/torch/manipulation.py
+++ b/ivy/functional/backends/torch/manipulation.py
@@ -324,14 +324,21 @@ def clip(
     *,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if hasattr(x_min, "dtype"):
-        x_min = torch.asarray(x_min, device=x.device)
-        x_max = torch.asarray(x_max, device=x.device)
-        promoted_type = torch.promote_types(x_min.dtype, x_max.dtype)
-        promoted_type = torch.promote_types(promoted_type, x.dtype)
-        x_min = x_min.to(promoted_type)
+    promoted_type = x.dtype
+    if x_min is not None:
+        if not hasattr(x_min, "dtype"):
+            x_min = ivy.array(x_min).data
+        promoted_type = ivy.as_native_dtype(ivy.promote_types(x.dtype, x_min.dtype))
+    if x_max is not None:
+        if not hasattr(x_max, "dtype"):
+            x_max = ivy.array(x_max).data
+        promoted_type = ivy.as_native_dtype(
+            ivy.promote_types(promoted_type, x_max.dtype)
+        )
         x_max = x_max.to(promoted_type)
-        x = x.to(promoted_type)
+    x = x.to(promoted_type)
+    if x_min is not None:
+        x_min = x_min.to(promoted_type)
     return torch.clamp(x, x_min, x_max, out=out)
 
 
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py
index f3564e10e7317..6ae42c6de5ca6 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py
@@ -1,7 +1,6 @@
 """Collection of tests for manipulation functions."""
 
 # global
-
 import numpy as np
 from hypothesis import strategies as st, assume
 
@@ -67,6 +66,61 @@ def _basic_min_x_max(draw):
     return [dtype], (value[0], min_val, max_val)
 
 
+@st.composite
+def _broadcastable_arrays(draw):
+    shapes = draw(helpers.mutually_broadcastable_shapes(num_shapes=3))
+    dtypes, values = draw(
+        helpers.dtype_and_values(
+            available_dtypes=helpers.get_dtypes("valid"), shape=shapes[0]
+        )
+    )
+    min_val = draw(
+        st.one_of(
+            st.floats(-5, 5),
+            st.just(None),
+            helpers.dtype_and_values(
+                available_dtypes=helpers.get_dtypes("valid"), shape=shapes[1]
+            ),
+        )
+    )
+    max_val = draw(
+        st.one_of(
+            st.floats(-5, 5),
+            st.just(None),
+            helpers.dtype_and_values(
+                available_dtypes=helpers.get_dtypes("valid"), shape=shapes[2]
+            ),
+        )
+    )
+    if min_val is None and max_val is None:
+        generate_max = draw(st.booleans())
+        if generate_max:
+            max_val = draw(
+                helpers.dtype_and_values(
+                    available_dtypes=helpers.get_dtypes("valid"), shape=shapes[2]
+                )
+            )
+        else:
+            min_val = draw(
+                helpers.dtype_and_values(
+                    available_dtypes=helpers.get_dtypes("valid"), shape=shapes[1]
+                )
+            )
+    if min_val is not None:
+        if not isinstance(min_val, float):
+            dtypes.append(min_val[0][0])
+            min_val = min_val[1][0]
+        else:
+            dtypes.append(ivy.float32)
+    if max_val is not None:
+        if not isinstance(max_val, float):
+            dtypes.append(max_val[0][0])
+            max_val = max_val[1][0]
+        else:
+            dtypes.append(ivy.float32)
+    return dtypes, values[0], min_val, max_val
+
+
 @st.composite
 def _constant_pad_helper(draw):
     dtype, value, shape = draw(
@@ -224,12 +278,12 @@ def _stack_helper(draw):
 # clip
 @handle_test(
     fn_tree="functional.ivy.clip",
-    dtype_x_min_max=_basic_min_x_max(),
+    dtype_x_min_max=_broadcastable_arrays(),
 )
 def test_clip(*, dtype_x_min_max, test_flags, backend_fw, fn_name, on_device):
-    dtypes, (x_list, min_val, max_val) = dtype_x_min_max
+    dtypes, x_list, min_val, max_val = dtype_x_min_max
     helpers.test_function(
-        input_dtypes=dtypes[0],
+        input_dtypes=dtypes,
         test_flags=test_flags,
         backend_to_test=backend_fw,
         fn_name=fn_name,
@@ -582,6 +636,28 @@ def test_squeeze(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_devic
     )
 
 
+@handle_test(
+    fn_tree="functional.ivy.stack",
+    dtypes_arrays=_stack_helper(),
+    axis=helpers.get_axis(
+        shape=st.shared(helpers.get_shape(min_num_dims=1), key="values_shape"),
+        force_int=True,
+    ),
+)
+def test_stack(*, dtypes_arrays, axis, test_flags, backend_fw, fn_name, on_device):
+    dtypes, arrays = dtypes_arrays
+
+    helpers.test_function(
+        input_dtypes=dtypes,
+        test_flags=test_flags,
+        backend_to_test=backend_fw,
+        fn_name=fn_name,
+        on_device=on_device,
+        arrays=arrays,
+        axis=axis,
+    )
+
+
 # stack
 @handle_test(
     fn_tree="functional.ivy.stack",
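For comparison, the `_broadcastable_arrays` strategy above can be approximated outside ivy's test helpers with `hypothesis.extra.numpy`. This standalone sketch (the name `clip_args` and the fixed float32 dtype are assumptions for illustration) shows the same core idea: draw three mutually broadcastable shapes and guarantee that at least one bound is present:

```python
import numpy as np
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays, mutually_broadcastable_shapes

@st.composite
def clip_args(draw):
    # Three shapes that broadcast together: input, lower bound, upper bound.
    shapes = draw(mutually_broadcastable_shapes(num_shapes=3)).input_shapes
    elems = st.floats(-5, 5, width=32)
    x = draw(arrays(np.float32, shapes[0], elements=elems))
    x_min = draw(st.none() | arrays(np.float32, shapes[1], elements=elems))
    x_max = draw(st.none() | arrays(np.float32, shapes[2], elements=elems))
    # The backends reject the all-None case, so force at least one bound,
    # just as _broadcastable_arrays does with its generate_max coin flip.
    if x_min is None and x_max is None:
        x_max = draw(arrays(np.float32, shapes[2], elements=elems))
    return x, x_min, x_max
```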