Skip to content

Commit

Permalink
Merge pull request #240 from crtrott/mdarray-to-mdspan
Browse files Browse the repository at this point in the history
Mdarray to mdspan
  • Loading branch information
crtrott authored Feb 9, 2023
2 parents 0a1ce8c + 0fbf368 commit a906275
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 2 deletions.
55 changes: 53 additions & 2 deletions include/experimental/__p1684_bits/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class mdarray {
using container_type = Container;
using mapping_type = typename layout_type::template mapping<extents_type>;
using element_type = ElementType;
using mdspan_type = mdspan<element_type, extents_type, layout_type>;
using const_mdspan_type = mdspan<const element_type, extents_type, layout_type>;
using value_type = remove_cv_t<element_type>;
using index_type = typename Extents::index_type;
using size_type = typename Extents::size_type;
Expand Down Expand Up @@ -405,7 +407,7 @@ class mdarray {
MDSPAN_INLINE_FUNCTION static constexpr rank_type rank_dynamic() noexcept { return extents_type::rank_dynamic(); }
MDSPAN_INLINE_FUNCTION static constexpr size_t static_extent(size_t r) noexcept { return extents_type::static_extent(r); }

MDSPAN_INLINE_FUNCTION constexpr extents_type extents() const noexcept { return map_.extents(); };
MDSPAN_INLINE_FUNCTION constexpr const extents_type& extents() const noexcept { return map_.extents(); };
MDSPAN_INLINE_FUNCTION constexpr index_type extent(size_t r) const noexcept { return map_.extents().extent(r); };
MDSPAN_INLINE_FUNCTION constexpr index_type size() const noexcept {
// return __impl::__size(*this);
Expand All @@ -420,12 +422,61 @@ class mdarray {
MDSPAN_INLINE_FUNCTION static constexpr bool is_always_exhaustive() noexcept { return mapping_type::is_always_exhaustive(); };
MDSPAN_INLINE_FUNCTION static constexpr bool is_always_strided() noexcept { return mapping_type::is_always_strided(); };

MDSPAN_INLINE_FUNCTION constexpr mapping_type mapping() const noexcept { return map_; };
MDSPAN_INLINE_FUNCTION constexpr const mapping_type& mapping() const noexcept { return map_; };
MDSPAN_INLINE_FUNCTION constexpr bool is_unique() const noexcept { return map_.is_unique(); };
MDSPAN_INLINE_FUNCTION constexpr bool is_exhaustive() const noexcept { return map_.is_exhaustive(); };
MDSPAN_INLINE_FUNCTION constexpr bool is_strided() const noexcept { return map_.is_strided(); };
MDSPAN_INLINE_FUNCTION constexpr index_type stride(size_t r) const { return map_.stride(r); };

// Converstion to mdspan
MDSPAN_TEMPLATE_REQUIRES(
class OtherElementType, class OtherExtents,
class OtherLayoutType, class OtherAccessorType,
/* requires */ (
_MDSPAN_TRAIT(is_assignable, mdspan_type,
mdspan<OtherElementType, OtherExtents, OtherLayoutType, OtherAccessorType>)
)
)
constexpr operator mdspan<OtherElementType, OtherExtents, OtherLayoutType, OtherAccessorType> () {
return mdspan_type(data(), map_);
}

MDSPAN_TEMPLATE_REQUIRES(
class OtherElementType, class OtherExtents,
class OtherLayoutType, class OtherAccessorType,
/* requires */ (
_MDSPAN_TRAIT(is_assignable, const_mdspan_type,
mdspan<OtherElementType, OtherExtents, OtherLayoutType, OtherAccessorType>)
)
)
constexpr operator mdspan<OtherElementType, OtherExtents, OtherLayoutType, OtherAccessorType> () const {
return const_mdspan_type(data(), map_);
}

MDSPAN_TEMPLATE_REQUIRES(
class OtherAccessorType = default_accessor<element_type>,
/* requires */ (
_MDSPAN_TRAIT(is_assignable, mdspan_type,
mdspan<element_type, extents_type, layout_type, OtherAccessorType>)
)
)
constexpr mdspan<element_type, extents_type, layout_type, OtherAccessorType>
to_mdspan(const OtherAccessorType& a = default_accessor<element_type>()) {
return mdspan<element_type, extents_type, layout_type, OtherAccessorType>(data(), map_, a);
}

MDSPAN_TEMPLATE_REQUIRES(
class OtherAccessorType = default_accessor<const element_type>,
/* requires */ (
_MDSPAN_TRAIT(is_assignable, const_mdspan_type,
mdspan<const element_type, extents_type, layout_type, OtherAccessorType>)
)
)
constexpr mdspan<const element_type, extents_type, layout_type, OtherAccessorType>
to_mdspan(const OtherAccessorType& a = default_accessor<const element_type>()) const {
return mdspan<const element_type, extents_type, layout_type, OtherAccessorType>(data(), map_, a);
}

private:
mapping_type map_;
container_type ctr_;
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,5 @@ endif()
# both of those don't work yet since its using vector
if(NOT MDSPAN_ENABLE_CUDA AND NOT MDSPAN_ENABLE_HIP)
mdspan_add_test(test_mdarray_ctors)
mdspan_add_test(test_mdarray_to_mdspan)
endif()
63 changes: 63 additions & 0 deletions tests/test_mdarray_to_mdspan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#include <experimental/mdarray>
#include <vector>

#include <gtest/gtest.h>


namespace stdex = std::experimental;
_MDSPAN_INLINE_VARIABLE constexpr auto dyn = stdex::dynamic_extent;

template<class MDSpan, class MDArray>
struct MDArrayToMDSpanOperatorTest {
using mdspan_t = MDSpan;
using c_mdspan_t = stdex::mdspan<const typename mdspan_t::element_type,
typename mdspan_t::extents_type,
typename mdspan_t::layout_type>;
static void test_check(mdspan_t mds, MDArray& mda) {
ASSERT_EQ(mds.data_handle(), mda.data());
ASSERT_EQ(mds.extents(), mda.extents());
ASSERT_EQ(mds.mapping(), mda.mapping());
}
static void test_check_const(c_mdspan_t mds, const MDArray& mda) {
ASSERT_EQ(mds.data_handle(), mda.data());
ASSERT_EQ(mds.extents(), mda.extents());
ASSERT_EQ(mds.mapping(), mda.mapping());
}
template<class ... ConstrArgs>
static void test(ConstrArgs ... args) {
MDArray a(args...);
test_check(a, a);
test_check(a.to_mdspan(), a);
const MDArray& c_a = a;
test_check_const(c_a, a);
test_check_const(c_a.to_mdspan(), a);
}
};

TEST(TestMDArray,mdarray_to_mdspan) {
MDArrayToMDSpanOperatorTest<stdex::mdspan <int, stdex::extents<int, dyn>>,
stdex::mdarray<int, stdex::extents<int, dyn>>>::test(
100
);
MDArrayToMDSpanOperatorTest<stdex::mdspan <int, stdex::extents<int, dyn>>,
stdex::mdarray<int, stdex::extents<int, 100>>>::test(
100
);
}


0 comments on commit a906275

Please sign in to comment.