Implement dpnp.allclose() for a device without fp64 aspect (#1536)
* Added support of dpnp.allclose() for a device without fp64 aspect

* Added tests for SYCL queue and USM type

* Handled a corner case with abs(MIN_INT)

* Increased test coverage

* Fixed typos

* Addressed review comments
antonwolfy authored Aug 28, 2023
1 parent 5a2913f commit 0fd57d4
Showing 12 changed files with 272 additions and 82 deletions.
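
For orientation, dpnp.allclose(a, b) checks |a - b| <= atol + rtol * |b| elementwise, as in NumPy. Below is a minimal usage sketch of the behavior this commit enables; it assumes a dpnp build containing this change, and the device query uses dpctl's public has_aspect_fp64 property:

import dpctl
import dpnp

# After this commit, allclose() runs natively even when the device lacks
# fp64: tolerances are downcast to float32 inside the kernel in that case.
dev = dpctl.select_default_device()
print(dev.has_aspect_fp64)  # may be False, e.g. on integrated GPUs

a = dpnp.array([1e10, 1e-7])
b = dpnp.array([1.00001e10, 1e-8])
print(dpnp.allclose(a, b))  # expected False: |1e-7 - 1e-8| exceeds the tolerance

# infinities compare equal only to infinities of the same sign
x = dpnp.array([1.0, dpnp.inf])
print(dpnp.allclose(x, x))  # expected True
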
2 changes: 2 additions & 0 deletions .github/workflows/conda-package.yml
@@ -24,9 +24,11 @@ env:
       test_random_state.py
       test_sort.py
       test_special.py
+      test_sycl_queue.py
       test_umath.py
+      test_usm_type.py
       third_party/cupy/linalg_tests/test_product.py
       third_party/cupy/logic_tests/test_comparison.py
       third_party/cupy/logic_tests/test_truth.py
       third_party/cupy/manipulation_tests/test_basic.py
       third_party/cupy/manipulation_tests/test_join.py
123 changes: 89 additions & 34 deletions dpnp/backend/kernels/dpnp_krnl_logic.cpp
@@ -74,7 +74,7 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
     sycl::nd_range<1> gws(gws_range, lws_range);

     auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
-        auto gr = nd_it.get_group();
+        auto gr = nd_it.get_sub_group();
         const auto max_gr_size = gr.get_max_local_range()[0];
         const size_t start =
             vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) +
@@ -127,8 +127,79 @@ DPCTLSyclEventRef (*dpnp_all_ext_c)(DPCTLSyclQueueRef,
                                     const DPCTLEventVectorRef) =
     dpnp_all_c<_DataType, _ResultType>;

-template <typename _DataType1, typename _DataType2, typename _ResultType>
-class dpnp_allclose_c_kernel;
+template <typename _DataType1, typename _DataType2, typename _TolType>
+class dpnp_allclose_kernel;

+template <typename _DataType1, typename _DataType2, typename _TolType>
+static sycl::event dpnp_allclose(sycl::queue &q,
+                                 const _DataType1 *array1,
+                                 const _DataType2 *array2,
+                                 bool *result,
+                                 const size_t size,
+                                 const _TolType rtol_val,
+                                 const _TolType atol_val)
+{
+    sycl::event fill_event = q.fill(result, true, 1);
+    if (!size) {
+        return fill_event;
+    }
+
+    constexpr size_t lws = 64;
+    constexpr size_t vec_sz = 8;
+
+    auto gws_range =
+        sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
+    auto lws_range = sycl::range<1>(lws);
+    sycl::nd_range<1> gws(gws_range, lws_range);
+
+    auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
+        auto gr = nd_it.get_sub_group();
+        const auto max_gr_size = gr.get_max_local_range()[0];
+        const auto gr_size = gr.get_local_linear_range();
+        const size_t start =
+            vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) +
+                      gr.get_group_linear_id() * max_gr_size);
+        const size_t end = sycl::min(start + vec_sz * gr_size, size);
+
+        // each work-item iterates over "vec_sz" elements in the input arrays
+        bool partial = true;
+
+        for (size_t i = start + gr.get_local_linear_id(); i < end; i += gr_size)
+        {
+            if constexpr (std::is_floating_point_v<_DataType1> &&
+                          std::is_floating_point_v<_DataType2>)
+            {
+                if (std::isinf(array1[i]) || std::isinf(array2[i])) {
+                    partial &= (array1[i] == array2[i]);
+                    continue;
+                }
+            }
+
+            // casting integral to floating type to avoid bad behavior
+            // on abs(MIN_INT), which leads to undefined result
+            using _Arr2Type = std::conditional_t<std::is_integral_v<_DataType2>,
+                                                 _TolType, _DataType2>;
+            _Arr2Type arr2 = static_cast<_Arr2Type>(array2[i]);
+
+            partial &= (std::abs(array1[i] - arr2) <=
+                        (atol_val + rtol_val * std::abs(arr2)));
+        }
+        partial = sycl::all_of_group(gr, partial);
+
+        if (gr.leader() && (partial == false)) {
+            result[0] = false;
+        }
+    };
+
+    auto kernel_func = [&](sycl::handler &cgh) {
+        cgh.depends_on(fill_event);
+        cgh.parallel_for<
+            class dpnp_allclose_kernel<_DataType1, _DataType2, _TolType>>(
+            gws, kernel_parallel_for_func);
+    };
+
+    return q.submit(kernel_func);
+}

 template <typename _DataType1, typename _DataType2, typename _ResultType>
 DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
@@ -140,6 +211,9 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
                                   double atol_val,
                                   const DPCTLEventVectorRef dep_event_vec_ref)
{
+    static_assert(std::is_same_v<_ResultType, bool>,
+                  "Boolean result type is required");
+
     // avoid warning unused variable
     (void)dep_event_vec_ref;

@@ -152,40 +226,21 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
     sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
     sycl::event event;

-    DPNPC_ptr_adapter<_DataType1> input1_ptr(q_ref, array1_in, size);
-    DPNPC_ptr_adapter<_DataType2> input2_ptr(q_ref, array2_in, size);
-    DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
-    const _DataType1 *array1 = input1_ptr.get_ptr();
-    const _DataType2 *array2 = input2_ptr.get_ptr();
-    _ResultType *result = result1_ptr.get_ptr();
-
-    result[0] = true;
+    const _DataType1 *array1 = static_cast<const _DataType1 *>(array1_in);
+    const _DataType2 *array2 = static_cast<const _DataType2 *>(array2_in);
+    bool *result = static_cast<bool *>(result1);

-    if (!size) {
-        return event_ref;
-    }
-
-    sycl::range<1> gws(size);
-    auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
-        size_t i = global_id[0];
-
-        if (std::abs(array1[i] - array2[i]) >
-            (atol_val + rtol_val * std::abs(array2[i])))
-        {
-            result[0] = false;
-        }
-    };
-
-    auto kernel_func = [&](sycl::handler &cgh) {
-        cgh.parallel_for<
-            class dpnp_allclose_c_kernel<_DataType1, _DataType2, _ResultType>>(
-            gws, kernel_parallel_for_func);
-    };
-
-    event = q.submit(kernel_func);
+    if (q.get_device().has(sycl::aspect::fp64)) {
+        event =
+            dpnp_allclose(q, array1, array2, result, size, rtol_val, atol_val);
+    }
+    else {
+        float rtol = static_cast<float>(rtol_val);
+        float atol = static_cast<float>(atol_val);
+        event = dpnp_allclose(q, array1, array2, result, size, rtol, atol);
+    }

     event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

     return DPCTLEvent_Copy(event_ref);
 }

@@ -269,7 +324,7 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
     sycl::nd_range<1> gws(gws_range, lws_range);

     auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
-        auto gr = nd_it.get_group();
+        auto gr = nd_it.get_sub_group();
         const auto max_gr_size = gr.get_max_local_range()[0];
         const size_t start =
             vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) +
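
The _Arr2Type cast in the new kernel guards the abs(MIN_INT) corner case named in the commit message: the absolute value of the most negative fixed-width integer is not representable in the same type, and in C++ std::abs(INT_MIN) is undefined behavior. A NumPy-level sketch of the same pitfall (illustrative only, not part of the diff):

import numpy as np

min_int = np.iinfo(np.int32).min  # -2147483648

# In fixed-width integer arithmetic the result wraps back to the
# negative value instead of producing +2147483648.
print(np.abs(np.int32(min_int)))  # -2147483648

# Casting to a floating type first, as the kernel does, yields the
# mathematically expected magnitude.
print(np.abs(np.float64(min_int)))  # 2147483648.0
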
25 changes: 24 additions & 1 deletion dpnp/dpnp_iface.py
@@ -66,6 +66,7 @@
"get_normalized_queue_device",
"get_usm_ndarray",
"get_usm_ndarray_or_scalar",
"is_supported_array_or_scalar",
"is_supported_array_type",
]

@@ -453,14 +454,36 @@ def get_usm_ndarray_or_scalar(a):
     return a if isscalar(a) else get_usm_ndarray(a)


+def is_supported_array_or_scalar(a):
+    """
+    Return ``True`` if `a` is a scalar or an array of either
+    :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray` type,
+    ``False`` otherwise.
+
+    Parameters
+    ----------
+    a : {scalar, dpnp_array, usm_ndarray}
+        An input scalar or an array to check the type of.
+
+    Returns
+    -------
+    out : bool
+        ``True`` if input `a` is a scalar or an array of supported type,
+        ``False`` otherwise.
+
+    """
+
+    return isscalar(a) or is_supported_array_type(a)
+
+
 def is_supported_array_type(a):
     """
     Return ``True`` if `a` is an array of either :class:`dpnp.ndarray`
     or :class:`dpctl.tensor.usm_ndarray` type, ``False`` otherwise.

     Parameters
     ----------
-    a : array
+    a : {dpnp_array, usm_ndarray}
         An input array to check the type of.

     Returns
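
A short sketch of how the new helper behaves next to the existing one, based on the docstrings above (assuming a dpnp build containing this commit):

import dpnp

x = dpnp.array([1, 2, 3])

print(dpnp.is_supported_array_type(x))          # True: dpnp.ndarray
print(dpnp.is_supported_array_type([1, 2, 3]))  # False: plain list
print(dpnp.is_supported_array_or_scalar(5))     # True: scalars are accepted
print(dpnp.is_supported_array_or_scalar(x))     # True: arrays still accepted
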
5 changes: 3 additions & 2 deletions dpnp/dpnp_iface_linearalgebra.py
@@ -358,14 +358,15 @@ def outer(x1, x2, out=None):
            [1, 2, 3]])

     """

     x1_is_scalar = dpnp.isscalar(x1)
     x2_is_scalar = dpnp.isscalar(x2)

     if x1_is_scalar and x2_is_scalar:
         pass
-    elif not (x1_is_scalar or dpnp.is_supported_array_type(x1)):
+    elif not dpnp.is_supported_array_or_scalar(x1):
         pass
-    elif not (x2_is_scalar or dpnp.is_supported_array_type(x2)):
+    elif not dpnp.is_supported_array_or_scalar(x2):
         pass
     else:
         x1_in = (
86 changes: 69 additions & 17 deletions dpnp/dpnp_iface_logic.py
@@ -152,42 +152,94 @@ def all(x, /, axis=None, out=None, keepdims=False, *, where=True):
     )


-def allclose(x1, x2, rtol=1.0e-5, atol=1.0e-8, **kwargs):
+def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, **kwargs):
     """
     Returns True if two arrays are element-wise equal within a tolerance.

     For full documentation refer to :obj:`numpy.allclose`.

+    Returns
+    -------
+    out : dpnp.ndarray
+        A boolean 0-dim array. If its value is ``True``,
+        two arrays are element-wise equal within a tolerance.
+
     Limitations
     -----------
-    Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar.
+    Parameters `a` and `b` are supported either as :class:`dpnp.ndarray`,
+    :class:`dpctl.tensor.usm_ndarray` or scalars, but both `a` and `b`
+    cannot be scalars at the same time.
     Keyword argument `kwargs` is currently unsupported.
     Otherwise the function will be executed sequentially on CPU.
+    Parameters `rtol` and `atol` are supported as scalars. Otherwise
+    a ``TypeError`` exception will be raised.
-    Input array data types are limited by supported DPNP :ref:`Data types`.
+    Input array data types are limited by supported integer and
+    floating DPNP :ref:`Data types`.

     See Also
     --------
     :obj:`dpnp.isclose` : Test whether two arrays are element-wise equal.
     :obj:`dpnp.all` : Test whether all elements evaluate to True.
     :obj:`dpnp.any` : Test whether any element evaluates to True.
     :obj:`dpnp.equal` : Return (x1 == x2) element-wise.

     Examples
     --------
     >>> import dpnp as np
-    >>> np.allclose([1e10,1e-7], [1.00001e10,1e-8])
-    >>> False
+    >>> a = np.array([1e10, 1e-7])
+    >>> b = np.array([1.00001e10, 1e-8])
+    >>> np.allclose(a, b)
+    array([False])
+    >>> a = np.array([1.0, np.nan])
+    >>> b = np.array([1.0, np.nan])
+    >>> np.allclose(a, b)
+    array([False])
+    >>> a = np.array([1.0, np.inf])
+    >>> b = np.array([1.0, np.inf])
+    >>> np.allclose(a, b)
+    array([ True])

     """

-    rtol_is_scalar = dpnp.isscalar(rtol)
-    atol_is_scalar = dpnp.isscalar(atol)
-    x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
-    x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
+    if dpnp.isscalar(a) and dpnp.isscalar(b):
+        # at least one of inputs has to be an array
+        pass
+    elif not (
+        dpnp.is_supported_array_or_scalar(a)
+        and dpnp.is_supported_array_or_scalar(b)
+    ):
+        pass
+    elif kwargs:
+        pass
+    else:
+        if not dpnp.isscalar(rtol):
+            raise TypeError(
+                "An argument `rtol` must be a scalar, but got {}".format(
+                    type(rtol)
+                )
+            )
+        elif not dpnp.isscalar(atol):
+            raise TypeError(
+                "An argument `atol` must be a scalar, but got {}".format(
+                    type(atol)
+                )
+            )

-    if x1_desc and x2_desc and not kwargs:
-        if not rtol_is_scalar or not atol_is_scalar:
-            pass
-        else:
-            result_obj = dpnp_allclose(x1_desc, x2_desc, rtol, atol).get_pyobj()
-            result = dpnp.convert_single_elem_array_to_scalar(result_obj)
-
-            return result
+        if dpnp.isscalar(a):
+            a = dpnp.full_like(b, fill_value=a)
+        elif dpnp.isscalar(b):
+            b = dpnp.full_like(a, fill_value=b)
+        elif a.shape != b.shape:
+            a, b = dpt.broadcast_arrays(a.get_array(), b.get_array())
+
+        a_desc = dpnp.get_dpnp_descriptor(a, copy_when_nondefault_queue=False)
+        b_desc = dpnp.get_dpnp_descriptor(b, copy_when_nondefault_queue=False)
+        if a_desc and b_desc:
+            return dpnp_allclose(a_desc, b_desc, rtol, atol).get_pyobj()

-    return call_origin(numpy.allclose, x1, x2, rtol=rtol, atol=atol, **kwargs)
+    return call_origin(numpy.allclose, a, b, rtol=rtol, atol=atol, **kwargs)


def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
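
To make the rewritten dispatch concrete: a scalar operand is expanded with dpnp.full_like, mismatched shapes go through dpctl.tensor.broadcast_arrays, and non-scalar rtol/atol raise TypeError. An illustrative sketch, assuming a dpnp build containing this commit:

import dpnp

a = dpnp.ones(3)

# a scalar operand is filled to the array's shape before comparison
print(dpnp.allclose(a, 1.0))   # expected True

# mismatched shapes are broadcast before the kernel runs
b = dpnp.ones((2, 3))
print(dpnp.allclose(a, b))     # expected True: (3,) broadcasts against (2, 3)

# rtol/atol must be scalars; anything else raises TypeError
try:
    dpnp.allclose(a, a, rtol=dpnp.array(1e-5))
except TypeError as exc:
    print(exc)  # An argument `rtol` must be a scalar, but got <class ...>
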
6 changes: 1 addition & 5 deletions tests/skipped_tests.tbl
@@ -438,11 +438,7 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transpose
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_array_scalar
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_finite
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite_equal_nan
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_min_int
+
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal
6 changes: 1 addition & 5 deletions tests/skipped_tests_gpu.tbl
@@ -584,11 +584,7 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transpose
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_array_scalar
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_finite
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite_equal_nan
-tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_min_int
+
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal
4 changes: 2 additions & 2 deletions tests/skipped_tests_gpu_no_fp64.tbl
@@ -443,8 +443,8 @@ tests/test_sycl_queue.py::test_array_creation[opencl:gpu:0-arange-arg0-kwargs0]
tests/test_sycl_queue.py::test_array_creation[level_zero:gpu:0-arange-arg0-kwargs0]
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-gradient-data10]
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-gradient-data10]
-tests/test_sycl_queue.py::test_2in_1out[opencl:gpu:0-power-data112-data212]
-tests/test_sycl_queue.py::test_2in_1out[level_zero:gpu:0-power-data112-data212]
+tests/test_sycl_queue.py::test_2in_1out[opencl:gpu:0-power-data113-data213]
+tests/test_sycl_queue.py::test_2in_1out[level_zero:gpu:0-power-data113-data213]
tests/test_sycl_queue.py::test_out_2in_1out[opencl:gpu:0-power-data19-data29]
tests/test_sycl_queue.py::test_out_2in_1out[level_zero:gpu:0-power-data19-data29]
tests/test_sycl_queue.py::test_eig[opencl:gpu:0]