Refactor ivy.clip function #21997

Merged Sep 1, 2023 (32 commits from branch ivy-clip; changes shown below are from the first 20 commits)

Commits
8529e70  Refactor `ivy.clip` function (NripeshN, Aug 16, 2023)
1b42810  merge_with_upstream.sh update (NripeshN, Aug 16, 2023)
ffc1c03  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 16, 2023)
58e4f87  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 16, 2023)
b52439b  Revert "Refactor `ivy.clip` function" (NripeshN, Aug 17, 2023)
fa78c2a  refactor jax clip (NripeshN, Aug 17, 2023)
4826b34  jax clip update (NripeshN, Aug 17, 2023)
f32e749  refactor clip for paddle and tensorflow (NripeshN, Aug 17, 2023)
767a185  torch clip refactor (NripeshN, Aug 17, 2023)
6a56f2e  🤖 Lint code (ivy-branch, Aug 17, 2023)
889c2c3  Update ivy/functional/backends/paddle/manipulation.py (NripeshN, Aug 18, 2023)
889eb54  Update ivy/functional/backends/paddle/manipulation.py (NripeshN, Aug 18, 2023)
9e91d49  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 18, 2023)
6123ab0  update numpy hints (NripeshN, Aug 18, 2023)
27bd292  test manipulation (NripeshN, Aug 19, 2023)
8998658  update test (NripeshN, Aug 21, 2023)
679e6af  🤖 Lint code (ivy-branch, Aug 21, 2023)
8b56aa6  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 22, 2023)
435fed7  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 22, 2023)
783bfb3  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 24, 2023)
07916c6  Update ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py (NripeshN, Aug 24, 2023)
fde68c1  Update ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py (NripeshN, Aug 24, 2023)
a4317f1  Update ivy/functional/backends/tensorflow/manipulation.py (NripeshN, Aug 24, 2023)
8a64392  Standardise for numpy backend (NripeshN, Aug 26, 2023)
5eb5aeb  small fix (NripeshN, Aug 28, 2023)
3822b5a  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 28, 2023)
00ade6f  Merge remote-tracking branch 'upstream/main' into ivy-clip (NripeshN, Aug 29, 2023)
f677a05  Fianl changes (NripeshN, Aug 29, 2023)
69d79fe  Update ivy/functional/backends/numpy/manipulation.py (NripeshN, Aug 29, 2023)
1202347  final fix (NripeshN, Sep 1, 2023)
eb1c4a3  Merge branch 'main' into ivy-clip (NripeshN, Sep 1, 2023)
7536469  🤖 Lint code (ivy-branch, Sep 1, 2023)
ivy/functional/backends/jax/manipulation.py (22 additions, 8 deletions)

```diff
@@ -203,14 +203,19 @@ def tile(

 def clip(
     x: JaxArray,
-    x_min: Union[Number, JaxArray],
-    x_max: Union[Number, JaxArray],
+    x_min: Optional[Union[Number, JaxArray]] = None,
+    x_max: Optional[Union[Number, JaxArray]] = None,
     /,
     *,
     out: Optional[JaxArray] = None,
 ) -> JaxArray:
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of the x_min or x_max must be provided")
+
     if (
-        hasattr(x_min, "dtype")
+        x_min is not None
+        and hasattr(x_min, "dtype")
+        and x_max is not None
         and hasattr(x_max, "dtype")
         and (x.dtype != x_min.dtype or x.dtype != x_max.dtype)
     ):
@@ -236,12 +241,21 @@ def clip(
         promoted_type = jnp.promote_types(promoted_type, x_max.dtype)
         x = x.astype(promoted_type)
     else:
-        promoted_type = jnp.promote_types(x.dtype, x_min.dtype)
-        promoted_type = jnp.promote_types(promoted_type, x_max.dtype)
-        x.astype(promoted_type)
+        promoted_type = jnp.promote_types(
+            x.dtype, x_min.dtype if x_min is not None else x.dtype
+        )
+        promoted_type = jnp.promote_types(
+            promoted_type, x_max.dtype if x_max is not None else x.dtype
+        )
+        x = x.astype(promoted_type)

     # jnp.clip isn't used because of inconsistent gradients
-    x = jnp.where(x > x_max, x_max, x)
-    return jnp.where(x < x_min, x_min, x)
+    if x_max is not None:
+        x = jnp.where(x > x_max, x_max, x)
+    if x_min is not None:
+        x = jnp.where(x < x_min, x_min, x)
+
+    return x


 @with_unsupported_dtypes({"0.4.14 and below": ("uint64",)}, backend_version)
```
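The comment about inconsistent gradients is worth unpacking: at a bound, `jnp.clip` (which lowers to `jnp.minimum`/`jnp.maximum`) and the `jnp.where` formulation can disagree about the derivative. A minimal sketch, not part of the PR, that probes both; the exact boundary values are an assumption to verify on your JAX version:

```python
# Minimal sketch (not from the PR): compare gradients of the where-based clip
# against jnp.clip exactly at the upper bound, where the two can differ.
import jax
import jax.numpy as jnp


def clip_where(x, lo, hi):
    # Same formulation as the backend above: pass-through unless out of bounds.
    x = jnp.where(x > hi, hi, x)
    return jnp.where(x < lo, lo, x)


print(jax.grad(clip_where)(1.0, 0.0, 1.0))             # gradient at x == hi
print(jax.grad(lambda x: jnp.clip(x, 0.0, 1.0))(1.0))  # jnp.clip at x == hi
```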
ivy/functional/backends/numpy/manipulation.py (2 additions, 2 deletions)

```diff
@@ -285,8 +285,8 @@ def unstack(

 def clip(
     x: np.ndarray,
-    x_min: Union[Number, np.ndarray],
-    x_max: Union[Number, np.ndarray],
+    x_min: Union[float, int, np.ndarray] = None,
+    x_max: Union[float, int, np.ndarray] = None,
     /,
     *,
     out: Optional[np.ndarray] = None,
```
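For reference, stock NumPy's `np.clip` already accepts `None` for one bound, which is the behaviour the refactor standardises on across backends. A quick check of the one-sided semantics:

```python
# One-sided clipping in plain NumPy: passing None for a bound disables it.
import numpy as np

x = np.array([-2.0, 0.5, 3.0])
assert np.array_equal(np.clip(x, 0.0, None), np.maximum(x, 0.0))  # lower only
assert np.array_equal(np.clip(x, None, 1.0), np.minimum(x, 1.0))  # upper only
```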
ivy/functional/backends/paddle/manipulation.py (10 additions, 3 deletions)

```diff
@@ -454,13 +454,20 @@ def swapaxes(

 def clip(
     x: paddle.Tensor,
-    x_min: Union[Number, paddle.Tensor],
-    x_max: Union[Number, paddle.Tensor],
+    x_min: Optional[Union[Number, paddle.Tensor]] = None,
+    x_max: Optional[Union[Number, paddle.Tensor]] = None,
     /,
     *,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    return paddle_backend.minimum(paddle_backend.maximum(x, x_min), x_max)
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of the x_min or x_max must be provided")
+
+    if x_min is not None:
+        x = paddle_backend.maximum(x, x_min)
+    if x_max is not None:
+        x = paddle_backend.minimum(x, x_max)
+    return x


 def unstack(
```
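The new branch structure keeps the old `minimum(maximum(x, x_min), x_max)` composition while letting either bound drop out. A small sketch of the same idea against Paddle's public API (assuming `paddle` is installed; this mirrors, rather than reproduces, the `paddle_backend` helpers above):

```python
# Sketch: one-sided clipping as a maximum/minimum composition in Paddle.
import paddle

x = paddle.to_tensor([-2.0, 0.5, 3.0])
lower = paddle.maximum(x, paddle.to_tensor(0.0))  # clip with x_min only
upper = paddle.minimum(x, paddle.to_tensor(1.0))  # clip with x_max only
print(lower, upper)
```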
ivy/functional/backends/tensorflow/manipulation.py (11 additions, 4 deletions)

```diff
@@ -336,27 +336,34 @@ def swapaxes(

 @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
 def clip(
     x: Union[tf.Tensor, tf.Variable],
-    x_min: Union[Number, tf.Tensor, tf.Variable],
-    x_max: Union[Number, tf.Tensor, tf.Variable],
+    x_min: Optional[Union[Number, tf.Tensor, tf.Variable]] = None,
+    x_max: Optional[Union[Number, tf.Tensor, tf.Variable]] = None,
     /,
     *,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of the x_min or x_max must be provided")
+
     if hasattr(x_min, "dtype") and hasattr(x_max, "dtype"):
         promoted_type = ivy.as_native_dtype(ivy.promote_types(x.dtype, x_min.dtype))
         promoted_type = ivy.as_native_dtype(
             ivy.promote_types(promoted_type, x_max.dtype)
         )
         x = tf.cast(x, promoted_type)
-        x_min = tf.cast(x_min, promoted_type)
-        x_max = tf.cast(x_max, promoted_type)
+        if x_min is not None:
+            x_min = tf.cast(x_min, promoted_type)
+        if x_max is not None:
+            x_max = tf.cast(x_max, promoted_type)

     if tf.size(x) == 0:
         ret = x
     elif x.dtype == tf.bool:
         ret = tf.clip_by_value(tf.cast(x, tf.float16), x_min, x_max)
         ret = tf.cast(ret, x.dtype)
     else:
         ret = tf.clip_by_value(x, x_min, x_max)

     return ret
```

Review thread on the `else` branch (resolved by NripeshN):

Contributor: I'm not sure I follow why we need an else case here, since the `x_min is not None and x_max is not None and x_min > x_max` case (i.e. `not cond`) can be handled by tf.experimental.numpy.clip as well, right? Maybe I'm missing something here.

Also, it looks like the frontend tests are failing because of `cond` with the following error:

    E     ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This is because x_min and x_max can also be arrays.

ShreyanshBardia (Contributor), Aug 28, 2023: Oh right, sorry, I missed the fact that they could be arrays too; I think `.any` should be able to resolve this. tf.experimental.numpy.clip also follows tf.clip_by_value behaviour when x_min > x_max, which we discussed in the comments above. I think the only differences are the numpy namespace and the fact that it can accept None. You can verify this with the following code:

    >>> print(tf.clip_by_value(tf.constant([1,10,13]),12,11))
    tf.Tensor([12 12 12], shape=(3,), dtype=int32)
    >>> print(tf.experimental.numpy.clip(tf.constant([1,10,13]),12,11))
    tf.Tensor([12 12 12], shape=(3,), dtype=int32)

Contributor: Ah, that's interesting. I expected tf.experimental.numpy.clip to behave exactly the same as np.clip.
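The `ValueError` in the thread comes from using an array-valued comparison in a Python `if`; the suggested `.any` fix reduces it to a scalar first. A hedged sketch of that pattern (illustrative only, not the code the PR finally shipped):

```python
# Sketch: reduce an elementwise bound comparison to a scalar before branching.
import tensorflow as tf

x_min = tf.constant([12, 0])
x_max = tf.constant([11, 5])
cond = tf.reduce_any(x_min > x_max)  # scalar tf.bool, safe to branch on
if bool(cond):
    # per the thread, tf.clip_by_value pins all values to x_min in this case
    print("some x_min exceeds x_max")
```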
ivy/functional/backends/torch/manipulation.py (13 additions, 6 deletions)

```diff
@@ -341,21 +341,28 @@ def swapaxes(
 )
 def clip(
     x: torch.Tensor,
-    x_min: Union[Number, torch.Tensor],
-    x_max: Union[Number, torch.Tensor],
+    x_min: Optional[Union[Number, torch.Tensor]] = None,
+    x_max: Optional[Union[Number, torch.Tensor]] = None,
     /,
     *,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if hasattr(x_min, "dtype"):
-        x_min = torch.asarray(x_min, device=x.device)
-        x_max = torch.asarray(x_max, device=x.device)
+    if x_min is None and x_max is None:
+        raise ValueError("At least one of the x_min or x_max must be provided")
+
+    if x_min is not None and hasattr(x_min, "dtype"):
+        x_min = torch.as_tensor(x_min, device=x.device)
+    if x_max is not None and hasattr(x_max, "dtype"):
+        x_max = torch.as_tensor(x_max, device=x.device)
+
+    if x_min is not None and x_max is not None:
         promoted_type = torch.promote_types(x_min.dtype, x_max.dtype)
         promoted_type = torch.promote_types(promoted_type, x.dtype)
         x_min = x_min.to(promoted_type)
         x_max = x_max.to(promoted_type)
         x = x.to(promoted_type)
-    return torch.clamp(x, x_min, x_max, out=out)
+
+    return torch.clamp(x, min=x_min, max=x_max, out=out)


 clip.support_native_out = True
```
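`torch.clamp` natively accepts `min` or `max` as `None` so long as at least one is given, which is what makes the keyword form at the end of the diff work. A quick illustration:

```python
# torch.clamp with a single bound: the other keyword may simply be None.
import torch

x = torch.tensor([-2.0, 0.5, 3.0])
print(torch.clamp(x, min=0.0))             # lower bound only
print(torch.clamp(x, max=1.0))             # upper bound only
print(torch.clamp(x, min=None, max=1.0))   # explicit None behaves the same
```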
ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py

```diff
@@ -2,7 +2,6 @@
 """Collection of tests for manipulation functions."""

 # global
-
 import numpy as np
 from hypothesis import strategies as st, assume
@@ -340,9 +339,16 @@ def _basic_min_x_max(draw):
             available_dtypes=helpers.get_dtypes("numeric"),
         )
     )
-    min_val = draw(helpers.array_values(dtype=dtype[0], shape=()))
+    min_val = draw(
+        st.one_of(st.just(None), helpers.array_values(dtype=dtype[0], shape=()))
+    )
     max_val = draw(
-        helpers.array_values(dtype=dtype[0], shape=()).filter(lambda x: x > min_val)
+        st.one_of(
+            st.just(None),
+            helpers.array_values(dtype=dtype[0], shape=()).filter(
+                lambda x: x > min_val
+            ),
+        )
     )
     return [dtype], (value[0], min_val, max_val)
```

Review comment on `_basic_min_x_max` (both threads resolved by NripeshN):

Contributor: Can we update min_val and max_val so that they can be arrays as well? They can be of any shape that is broadcastable to the shape of the input x. You should use the broadcast helper functions in the Array helpers section here: https://unify.ai/docs/ivy/docs/helpers/ivy_tests.test_ivy.helpers.hypothesis_helpers/ivy_tests.test_ivy.helpers.hypothesis_helpers.array_helpers.html#ivy_tests.test_ivy.helpers.hypothesis_helpers.array_helpers.mutually_broadcastable_shapes
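The `st.one_of(st.just(None), ...)` pattern in the diff is how the strategy makes each bound independently optional. A self-contained sketch of the same pattern with plain Hypothesis strategies (the names here are illustrative, not the PR's helpers):

```python
# Sketch: each bound is drawn as either None or a scalar, mirroring the
# refactored clip signature; allow_nan=False keeps any x > min filter sane.
from hypothesis import given, strategies as st

scalar = st.floats(-10, 10, allow_nan=False)
maybe_bound = st.one_of(st.just(None), scalar)


@given(lo=maybe_bound, hi=maybe_bound)
def check_optional_bounds(lo, hi):
    assert lo is None or isinstance(lo, float)
    assert hi is None or isinstance(hi, float)


check_optional_bounds()  # run the property
```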