Skip to content

Commit

Permalink
clip permits None for min and max (#1505)
Browse files Browse the repository at this point in the history
* Fixes `dpt.copy` returning TypeError instead of raising

When provided a non-usm_ndarray-input to copy, copy would return the error instead of raising it

* Permits clip arguments `min` and `max` to both be `None`

Also resolves gh-1489

* Specify that Python scalars are permitted for `max` and `min` in `clip`

* Adds tests to `test_tensor_clip.py` improve `_clip.py` coverage
  • Loading branch information
ndgrigorian authored Jan 25, 2024
1 parent 8ed8ef2 commit 8f82fe1
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 21 deletions.
80 changes: 67 additions & 13 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
return dpt.dtype(ti.default_device_int_type(dev))
if isinstance(dtype, WeakComplexType):
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
return st_dtype, dpt.complex64
return dpt.complex64
return _to_device_supported_dtype(dpt.complex128, dev)
return (_to_device_supported_dtype(dpt.float64, dev),)
return _to_device_supported_dtype(dpt.float64, dev)
else:
return st_dtype
else:
Expand All @@ -197,8 +197,6 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):


def _clip_none(x, val, out, order, _binary_fn):
if order not in ["K", "C", "F", "A"]:
order = "K"
q1, x_usm_type = x.sycl_queue, x.usm_type
q2, val_usm_type = _get_queue_usm_type(val)
if q2 is None:
Expand Down Expand Up @@ -391,9 +389,8 @@ def _clip_none(x, val, out, order, _binary_fn):
return out


# need to handle logic for min or max being None
def clip(x, min=None, max=None, out=None, order="K"):
"""clip(x, min, max, out=None, order="K")
def clip(x, /, min=None, max=None, out=None, order="K"):
"""clip(x, min=None, max=None, out=None, order="K")
Clips to the range [`min_i`, `max_i`] for each element `x_i`
in `x`.
Expand All @@ -402,14 +399,14 @@ def clip(x, min=None, max=None, out=None, order="K"):
x (usm_ndarray): Array containing elements to clip.
Must be compatible with `min` and `max` according
to broadcasting rules.
min ({None, usm_ndarray}, optional): Array containing minimum values.
min ({None, Union[usm_ndarray, bool, int, float, complex]}, optional):
Array containing minimum values.
Must be compatible with `x` and `max` according
to broadcasting rules.
Only one of `min` and `max` can be `None`.
max ({None, usm_ndarray}, optional): Array containing maximum values.
max ({None, Union[usm_ndarray, bool, int, float, complex]}, optional):
Array containing maximum values.
Must be compatible with `x` and `min` according
to broadcasting rules.
Only one of `min` and `max` can be `None`.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array must have the correct shape and the expected data type.
Expand All @@ -428,10 +425,67 @@ def clip(x, min=None, max=None, out=None, order="K"):
"Expected `x` to be of dpctl.tensor.usm_ndarray type, got "
f"{type(x)}"
)
if order not in ["K", "C", "F", "A"]:
order = "K"
if min is None and max is None:
raise ValueError(
"only one of `min` and `max` is permitted to be `None`"
exec_q = x.sycl_queue
orig_out = out
if out is not None:
if not isinstance(out, dpt.usm_ndarray):
raise TypeError(
"output array must be of usm_ndarray type, got "
f"{type(out)}"
)

if out.shape != x.shape:
raise ValueError(
"The shape of input and output arrays are "
f"inconsistent. Expected output shape is {x.shape}, "
f"got {out.shape}"
)

if x.dtype != out.dtype:
raise ValueError(
f"Output array of type {x.dtype} is needed, "
f"got {out.dtype}"
)

if (
dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
is None
):
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

if ti._array_overlap(x, out):
if not ti._same_logical_tensors(x, out):
out = dpt.empty_like(out)
else:
return out
else:
if order == "K":
out = _empty_like_orderK(x, x.dtype)
else:
if order == "A":
order = "F" if x.flags.f_contiguous else "C"
out = dpt.empty_like(x, order=order)

ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=x, dst=out, sycl_queue=exec_q
)
if not (orig_out is None or orig_out is out):
# Copy the out data from temporary buffer to original memory
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=out,
dst=orig_out,
sycl_queue=exec_q,
depends=[copy_ev],
)
ht_copy_out_ev.wait()
out = orig_out
ht_copy_ev.wait()
return out
elif max is None:
return _clip_none(x, min, out, order, tei._maximum)
elif min is None:
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def copy(usm_ary, order="K"):
)
order = order[0].upper()
if not isinstance(usm_ary, dpt.usm_ndarray):
return TypeError(
raise TypeError(
f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
)
copy_order = "C"
Expand Down
126 changes: 119 additions & 7 deletions dpctl/tests/test_tensor_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

import dpctl
import dpctl.tensor as dpt
from dpctl.tensor._type_utils import _can_cast
from dpctl.tensor._elementwise_common import _get_dtype
from dpctl.tensor._type_utils import (
_can_cast,
_strong_dtype_num_kind,
_weak_type_num_kind,
)
from dpctl.utils import ExecutionPlacementError

_all_dtypes = [
Expand Down Expand Up @@ -194,6 +199,15 @@ def test_clip_out_need_temporary():
dpt.clip(x[:6], 2, 3, out=x[-6:])
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)

x = dpt.arange(12, dtype="i4")
dpt.clip(x[:6], out=x[-6:])
expected = dpt.arange(6, dtype="i4")
assert dpt.all(x[:-6] == expected) and dpt.all(x[-6:] == expected)

x = dpt.ones(10, dtype="i4")
dpt.clip(x, out=x)
assert dpt.all(x == 1)

x = dpt.full(6, 3, dtype="i4")
a_min = dpt.full(10, 2, dtype="i4")
a_max = dpt.asarray(4, dtype="i4")
Expand Down Expand Up @@ -227,6 +241,21 @@ def test_clip_arg_validation():
with pytest.raises(TypeError):
dpt.clip(check, x1, x2)

with pytest.raises(ValueError):
dpt.clip(x1, check, x2)

with pytest.raises(ValueError):
dpt.clip(x1, check)

with pytest.raises(TypeError):
dpt.clip(x1, x1, x2, out=check)

with pytest.raises(TypeError):
dpt.clip(x1, x2, out=check)

with pytest.raises(TypeError):
dpt.clip(x1, out=check)


@pytest.mark.parametrize(
"dt1,dt2", [("i4", "i4"), ("i4", "i2"), ("i2", "i4"), ("i1", "i2")]
Expand Down Expand Up @@ -599,22 +628,40 @@ def test_clip_max_less_than_min():
assert dpt.all(res == 0)


def test_clip_minmax_weak_types():
@pytest.mark.parametrize("dt", ["?", "i4", "f4", "c8"])
def test_clip_minmax_weak_types(dt):
get_queue_or_skip()

x = dpt.zeros(10, dtype=dpt.bool)
x = dpt.zeros(10, dtype=dt)
min_list = [False, 0, 0.0, 0.0 + 0.0j]
max_list = [True, 1, 1.0, 1.0 + 0.0j]

for min_v, max_v in zip(min_list, max_list):
if isinstance(min_v, bool) and isinstance(max_v, bool):
y = dpt.clip(x, min_v, max_v)
assert isinstance(y, dpt.usm_ndarray)
st_dt = _strong_dtype_num_kind(dpt.dtype(dt))
wk_dt1 = _weak_type_num_kind(_get_dtype(min_v, x.sycl_device))
wk_dt2 = _weak_type_num_kind(_get_dtype(max_v, x.sycl_device))

if st_dt >= wk_dt1 and st_dt >= wk_dt2:
r = dpt.clip(x, min_v, max_v)
assert isinstance(r, dpt.usm_ndarray)
else:
with pytest.raises(ValueError):
dpt.clip(x, min_v, max_v)

if st_dt >= wk_dt1:
r = dpt.clip(x, min_v)
assert isinstance(r, dpt.usm_ndarray)

r = dpt.clip(x, None, min_v)
assert isinstance(r, dpt.usm_ndarray)
else:
with pytest.raises(ValueError):
dpt.clip(x, min_v)
with pytest.raises(ValueError):
dpt.clip(x, None, max_v)


def test_clip_max_weak_types():
def test_clip_max_weak_type_errors():
get_queue_or_skip()

x = dpt.zeros(10, dtype="i4")
Expand All @@ -626,6 +673,15 @@ def test_clip_max_weak_types():
with pytest.raises(ValueError):
dpt.clip(x, 2.5, m)

with pytest.raises(ValueError):
dpt.clip(x, 2.5)

with pytest.raises(ValueError):
dpt.clip(dpt.astype(x, "?"), 2)

with pytest.raises(ValueError):
dpt.clip(dpt.astype(x, "f4"), complex(2))


def test_clip_unaligned():
get_queue_or_skip()
Expand All @@ -636,3 +692,59 @@ def test_clip_unaligned():

expected = dpt.full(512, 2, dtype="i4")
assert dpt.all(dpt.clip(x[1:], a_min, a_max) == expected)


def test_clip_none_args():
get_queue_or_skip()

x = dpt.arange(10, dtype="i4")
r = dpt.clip(x)
assert dpt.all(x == r)


def test_clip_shape_errors():
get_queue_or_skip()

x = dpt.ones((4, 4), dtype="i4")
a_min = dpt.ones(5, dtype="i4")
a_max = dpt.ones(5, dtype="i4")

with pytest.raises(ValueError):
dpt.clip(x, a_min, a_max)

with pytest.raises(ValueError):
dpt.clip(x, a_min)

with pytest.raises(ValueError):
dpt.clip(x, 0, 1, out=a_min)

with pytest.raises(ValueError):
dpt.clip(x, 0, out=a_min)

with pytest.raises(ValueError):
dpt.clip(x, out=a_min)


def test_clip_compute_follows_data():
q1 = get_queue_or_skip()
q2 = get_queue_or_skip()

x = dpt.ones(10, dtype="i4", sycl_queue=q1)
a_min = dpt.ones(10, dtype="i4", sycl_queue=q2)
a_max = dpt.ones(10, dtype="i4", sycl_queue=q1)
res = dpt.empty_like(x, sycl_queue=q2)

with pytest.raises(ExecutionPlacementError):
dpt.clip(x, a_min, a_max)

with pytest.raises(ExecutionPlacementError):
dpt.clip(x, dpt.ones_like(x), a_max, out=res)

with pytest.raises(ExecutionPlacementError):
dpt.clip(x, a_min)

with pytest.raises(ExecutionPlacementError):
dpt.clip(x, None, a_max, out=res)

with pytest.raises(ExecutionPlacementError):
dpt.clip(x, out=res)

0 comments on commit 8f82fe1

Please sign in to comment.