From f4bd89735411f31f185871b6572ab74fd4bafc90 Mon Sep 17 00:00:00 2001 From: Christian Trott Date: Sun, 5 Feb 2023 09:09:29 -0700 Subject: [PATCH 1/2] Add missing mdspan defines in mdarray --- include/experimental/__p1684_bits/mdarray.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/experimental/__p1684_bits/mdarray.hpp b/include/experimental/__p1684_bits/mdarray.hpp index ea5b9cf8..2628e11f 100644 --- a/include/experimental/__p1684_bits/mdarray.hpp +++ b/include/experimental/__p1684_bits/mdarray.hpp @@ -72,6 +72,8 @@ class mdarray { using container_type = Container; using mapping_type = typename layout_type::template mapping; using element_type = ElementType; + using mdspan_type = mdspan; + using const_mdspan_type = mdspan; using value_type = remove_cv_t; using index_type = typename Extents::index_type; using size_type = typename Extents::size_type; From 0fbf36814efb3e3a17408c081bc21c84ab9680c5 Mon Sep 17 00:00:00 2001 From: Christian Trott Date: Sun, 5 Feb 2023 10:20:56 -0700 Subject: [PATCH 2/2] Add missing mdarray to mdspan conversion functions --- include/experimental/__p1684_bits/mdarray.hpp | 53 +++++++++++++++- tests/CMakeLists.txt | 1 + tests/test_mdarray_to_mdspan.cpp | 63 +++++++++++++++++++ 3 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 tests/test_mdarray_to_mdspan.cpp diff --git a/include/experimental/__p1684_bits/mdarray.hpp b/include/experimental/__p1684_bits/mdarray.hpp index 2628e11f..c84fd78a 100644 --- a/include/experimental/__p1684_bits/mdarray.hpp +++ b/include/experimental/__p1684_bits/mdarray.hpp @@ -407,7 +407,7 @@ class mdarray { MDSPAN_INLINE_FUNCTION static constexpr rank_type rank_dynamic() noexcept { return extents_type::rank_dynamic(); } MDSPAN_INLINE_FUNCTION static constexpr size_t static_extent(size_t r) noexcept { return extents_type::static_extent(r); } - MDSPAN_INLINE_FUNCTION constexpr extents_type extents() const noexcept { return map_.extents(); }; + MDSPAN_INLINE_FUNCTION constexpr const extents_type& extents() const noexcept { return map_.extents(); }; MDSPAN_INLINE_FUNCTION constexpr index_type extent(size_t r) const noexcept { return map_.extents().extent(r); }; MDSPAN_INLINE_FUNCTION constexpr index_type size() const noexcept { // return __impl::__size(*this); @@ -422,12 +422,61 @@ class mdarray { MDSPAN_INLINE_FUNCTION static constexpr bool is_always_exhaustive() noexcept { return mapping_type::is_always_exhaustive(); }; MDSPAN_INLINE_FUNCTION static constexpr bool is_always_strided() noexcept { return mapping_type::is_always_strided(); }; - MDSPAN_INLINE_FUNCTION constexpr mapping_type mapping() const noexcept { return map_; }; + MDSPAN_INLINE_FUNCTION constexpr const mapping_type& mapping() const noexcept { return map_; }; MDSPAN_INLINE_FUNCTION constexpr bool is_unique() const noexcept { return map_.is_unique(); }; MDSPAN_INLINE_FUNCTION constexpr bool is_exhaustive() const noexcept { return map_.is_exhaustive(); }; MDSPAN_INLINE_FUNCTION constexpr bool is_strided() const noexcept { return map_.is_strided(); }; MDSPAN_INLINE_FUNCTION constexpr index_type stride(size_t r) const { return map_.stride(r); }; + // Converstion to mdspan + MDSPAN_TEMPLATE_REQUIRES( + class OtherElementType, class OtherExtents, + class OtherLayoutType, class OtherAccessorType, + /* requires */ ( + _MDSPAN_TRAIT(is_assignable, mdspan_type, + mdspan) + ) + ) + constexpr operator mdspan () { + return mdspan_type(data(), map_); + } + + MDSPAN_TEMPLATE_REQUIRES( + class OtherElementType, class OtherExtents, + class OtherLayoutType, class OtherAccessorType, + /* requires */ ( + _MDSPAN_TRAIT(is_assignable, const_mdspan_type, + mdspan) + ) + ) + constexpr operator mdspan () const { + return const_mdspan_type(data(), map_); + } + + MDSPAN_TEMPLATE_REQUIRES( + class OtherAccessorType = default_accessor, + /* requires */ ( + _MDSPAN_TRAIT(is_assignable, mdspan_type, + mdspan) + ) + ) + constexpr mdspan + to_mdspan(const OtherAccessorType& a = default_accessor()) { + return mdspan(data(), map_, a); + } + + MDSPAN_TEMPLATE_REQUIRES( + class OtherAccessorType = default_accessor, + /* requires */ ( + _MDSPAN_TRAIT(is_assignable, const_mdspan_type, + mdspan) + ) + ) + constexpr mdspan + to_mdspan(const OtherAccessorType& a = default_accessor()) const { + return mdspan(data(), map_, a); + } + private: mapping_type map_; container_type ctr_; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index afd2f3b1..61b663f1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -64,4 +64,5 @@ endif() # both of those don't work yet since its using vector if(NOT MDSPAN_ENABLE_CUDA AND NOT MDSPAN_ENABLE_HIP) mdspan_add_test(test_mdarray_ctors) +mdspan_add_test(test_mdarray_to_mdspan) endif() diff --git a/tests/test_mdarray_to_mdspan.cpp b/tests/test_mdarray_to_mdspan.cpp new file mode 100644 index 00000000..43d8d4bb --- /dev/null +++ b/tests/test_mdarray_to_mdspan.cpp @@ -0,0 +1,63 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#include +#include + +#include + + +namespace stdex = std::experimental; +_MDSPAN_INLINE_VARIABLE constexpr auto dyn = stdex::dynamic_extent; + +template +struct MDArrayToMDSpanOperatorTest { + using mdspan_t = MDSpan; + using c_mdspan_t = stdex::mdspan; + static void test_check(mdspan_t mds, MDArray& mda) { + ASSERT_EQ(mds.data_handle(), mda.data()); + ASSERT_EQ(mds.extents(), mda.extents()); + ASSERT_EQ(mds.mapping(), mda.mapping()); + } + static void test_check_const(c_mdspan_t mds, const MDArray& mda) { + ASSERT_EQ(mds.data_handle(), mda.data()); + ASSERT_EQ(mds.extents(), mda.extents()); + ASSERT_EQ(mds.mapping(), mda.mapping()); + } + template + static void test(ConstrArgs ... args) { + MDArray a(args...); + test_check(a, a); + test_check(a.to_mdspan(), a); + const MDArray& c_a = a; + test_check_const(c_a, a); + test_check_const(c_a.to_mdspan(), a); + } +}; + +TEST(TestMDArray,mdarray_to_mdspan) { + MDArrayToMDSpanOperatorTest>, + stdex::mdarray>>::test( + 100 + ); + MDArrayToMDSpanOperatorTest>, + stdex::mdarray>>::test( + 100 + ); +} + +