Skip to content

Commit

Permalink
re-write dpnp.hypot (#1560)
Browse files Browse the repository at this point in the history
* re-write dpnp.hypot

* address comments

* fix docstring and precommit
  • Loading branch information
vtavana authored Sep 29, 2023
1 parent e3be611 commit 946ff08
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 87 deletions.
81 changes: 81 additions & 0 deletions dpnp/backend/extensions/vm/hypot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>

#include "common.hpp"
#include "types_matrix.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace vm
{
template <typename T>
sycl::event hypot_contig_impl(sycl::queue exec_q,
const std::int64_t n,
const char *in_a,
const char *in_b,
char *out_y,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);

return mkl_vm::hypot(exec_q,
n, // number of elements to be calculated
a, // pointer `a` containing 1st input vector of size n
b, // pointer `b` containing 2nd input vector of size n
y, // pointer `y` to the output vector of size n
depends);
}

template <typename fnT, typename T>
struct HypotContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename types::HypotOutputType<T>::value_type, void>)
{
return nullptr;
}
else {
return hypot_contig_impl<T>;
}
}
};
} // namespace vm
} // namespace ext
} // namespace backend
} // namespace dpnp
15 changes: 15 additions & 0 deletions dpnp/backend/extensions/vm/types_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,21 @@ struct FloorOutputType
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::hypot<T> function.
*
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
*/
template <typename T>
struct HypotOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.
Expand Down
33 changes: 32 additions & 1 deletion dpnp/backend/extensions/vm/vm_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "cosh.hpp"
#include "div.hpp"
#include "floor.hpp"
#include "hypot.hpp"
#include "ln.hpp"
#include "mul.hpp"
#include "pow.hpp"
Expand Down Expand Up @@ -74,11 +75,12 @@ static unary_impl_fn_ptr_t atan_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t atan2_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t atanh_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t cosh_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
Expand Down Expand Up @@ -494,6 +496,35 @@ PYBIND11_MODULE(_vm_impl, m)
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
}

// BinaryUfunc: ==== Hypot(x1, x2) ====
{
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
vm_ext::HypotContigFactory>(
hypot_dispatch_vector);

auto hypot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
hypot_dispatch_vector);
};
m.def("_hypot", hypot_pyapi,
"Call `hypot` function from OneMKL VM library to compute element "
"by element hypotenuse of `x`",
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
py::arg("dst"), py::arg("depends") = py::list());

auto hypot_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
arrayT src2, arrayT dst) {
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
hypot_dispatch_vector);
};
m.def("_mkl_hypot_to_call", hypot_need_to_call_pyapi,
"Check input arguments to answer if `hypot` function from "
"OneMKL VM library can be used",
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
py::arg("dst"));
}

// UnaryUfunc: ==== Ln(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
Expand Down
2 changes: 0 additions & 2 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@ enum class DPNPFuncName : size_t
DPNP_FN_GREATER_EQUAL_EXT, /**< Used in numpy.greater_equal() impl, requires
extra parameters */
DPNP_FN_HYPOT, /**< Used in numpy.hypot() impl */
DPNP_FN_HYPOT_EXT, /**< Used in numpy.hypot() impl, requires extra
parameters */
DPNP_FN_IDENTITY, /**< Used in numpy.identity() impl */
DPNP_FN_IDENTITY_EXT, /**< Used in numpy.identity() impl, requires extra
parameters */
Expand Down
13 changes: 0 additions & 13 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1514,19 +1514,6 @@ static void func_map_elemwise_2arg_3type_short_core(func_map_t &fmap)
func_type_map_t::find_type<FT1>,
func_type_map_t::find_type<FTs>>}),
...);
((fmap[DPNPFuncName::DPNP_FN_HYPOT_EXT][FT1][FTs] =
{get_floating_res_type<FT1, FTs>(),
(void *)dpnp_hypot_c_ext<
func_type_map_t::find_type<get_floating_res_type<FT1, FTs>()>,
func_type_map_t::find_type<FT1>,
func_type_map_t::find_type<FTs>>,
get_floating_res_type<FT1, FTs, std::false_type>(),
(void *)dpnp_hypot_c_ext<
func_type_map_t::find_type<
get_floating_res_type<FT1, FTs, std::false_type>()>,
func_type_map_t::find_type<FT1>,
func_type_map_t::find_type<FTs>>}),
...);
((fmap[DPNPFuncName::DPNP_FN_MAXIMUM_EXT][FT1][FTs] =
{get_floating_res_type<FT1, FTs, std::true_type, std::true_type>(),
(void *)dpnp_maximum_c_ext<
Expand Down
4 changes: 0 additions & 4 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_FMOD_EXT
DPNP_FN_FULL
DPNP_FN_FULL_LIKE
DPNP_FN_HYPOT
DPNP_FN_HYPOT_EXT
DPNP_FN_IDENTITY
DPNP_FN_IDENTITY_EXT
DPNP_FN_INV
Expand Down Expand Up @@ -384,8 +382,6 @@ cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)
"""
Mathematical functions
"""
cpdef dpnp_descriptor dpnp_hypot(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_fmax(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_fmin(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
Expand Down
9 changes: 0 additions & 9 deletions dpnp/dpnp_algo/dpnp_algo_mathematical.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ __all__ += [
"dpnp_fabs",
"dpnp_fmod",
"dpnp_gradient",
'dpnp_hypot',
"dpnp_fmax",
"dpnp_fmin",
"dpnp_modf",
Expand Down Expand Up @@ -273,14 +272,6 @@ cpdef utils.dpnp_descriptor dpnp_gradient(utils.dpnp_descriptor y1, int dx=1):
return result


cpdef utils.dpnp_descriptor dpnp_hypot(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
utils.dpnp_descriptor out=None,
object where=True):
return call_fptr_2in_1out_strides(DPNP_FN_HYPOT_EXT, x1_obj, x2_obj, dtype, out, where)


cpdef utils.dpnp_descriptor dpnp_fmax(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
Expand Down
61 changes: 61 additions & 0 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"dpnp_floor_divide",
"dpnp_greater",
"dpnp_greater_equal",
"dpnp_hypot",
"dpnp_imag",
"dpnp_invert",
"dpnp_isfinite",
Expand Down Expand Up @@ -1264,6 +1265,66 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"):
return dpnp_array._create_from_usm_ndarray(res_usm)


_hypot_docstring_ = """
hypot(x1, x2, out=None, order="K")
Calculates the hypotenuse for a right triangle with "legs" `x1_i` and `x2_i` of
input arrays `x1` and `x2`.
Args:
x1 (dpnp.ndarray):
First input array, expected to have a real-valued data type.
x2 (dpnp.ndarray):
Second input array, also expected to have a real-valued data type.
out ({None, dpnp.ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", None, optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
dpnp.ndarray:
An array containing the element-wise hypotenuse. The data type
of the returned array is determined by the Type Promotion Rules.
"""


def _call_hypot(src1, src2, dst, sycl_queue, depends=None):
"""A callback to register in BinaryElementwiseFunc class of dpctl.tensor"""

if depends is None:
depends = []

if vmi._mkl_hypot_to_call(sycl_queue, src1, src2, dst):
# call pybind11 extension for hypot() function from OneMKL VM
return vmi._hypot(sycl_queue, src1, src2, dst, depends)
return ti._hypot(src1, src2, dst, sycl_queue, depends)


hypot_func = BinaryElementwiseFunc(
"hypot",
ti._hypot_result_type,
_call_hypot,
_hypot_docstring_,
)


def dpnp_hypot(x1, x2, out=None, order="K"):
"""
Invokes hypot() function from pybind11 extension of OneMKL VM if possible.
Otherwise fully relies on dpctl.tensor implementation for hypot() function.
"""

# dpctl.tensor only works with usm_ndarray or scalar
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

res_usm = hypot_func(
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
)
return dpnp_array._create_from_usm_ndarray(res_usm)


_imag_docstring = """
imag(x, out=None, order="K")
Expand Down
Loading

0 comments on commit 946ff08

Please sign in to comment.