Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework mdspan concept emulation #2213

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions libcudacxx/include/cuda/std/__mdspan/default_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ struct default_accessor

__MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr default_accessor() noexcept = default;

__MDSPAN_TEMPLATE_REQUIRES(class _OtherElementType,
/* requires */ (_CCCL_TRAIT(is_convertible, _OtherElementType (*)[], element_type (*)[])))
__MDSPAN_INLINE_FUNCTION
constexpr default_accessor(default_accessor<_OtherElementType>) noexcept {}
_LIBCUDACXX_TEMPLATE(class _OtherElementType)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(is_convertible, _OtherElementType (*)[], element_type (*)[]))
__MDSPAN_INLINE_FUNCTION constexpr default_accessor(default_accessor<_OtherElementType>) noexcept {}

__MDSPAN_INLINE_FUNCTION
constexpr data_handle_type offset(data_handle_type __p, size_t __i) const noexcept
Expand Down
95 changes: 41 additions & 54 deletions libcudacxx/include/cuda/std/__mdspan/extents.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,13 @@ class extents
__MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr extents() noexcept = default;

// Converting constructor
__MDSPAN_TEMPLATE_REQUIRES(
class _OtherIndexType,
size_t... _OtherExtents,
/* requires */
(
/* multi-stage check to protect from invalid pack expansion when sizes don't match? */
decltype(__detail::__check_compatible_extents(
integral_constant<bool, sizeof...(_Extents) == sizeof...(_OtherExtents)>{},
_CUDA_VSTD::integer_sequence<size_t, _Extents...>{},
_CUDA_VSTD::integer_sequence<size_t, _OtherExtents...>{}))::value))
_LIBCUDACXX_TEMPLATE(class _OtherIndexType, size_t... _OtherExtents)
_LIBCUDACXX_REQUIRES(
/* multi-stage check to protect from invalid pack expansion when sizes don't match? */
(decltype(__detail::__check_compatible_extents(
integral_constant<bool, sizeof...(_Extents) == sizeof...(_OtherExtents)>{},
_CUDA_VSTD::integer_sequence<size_t, _Extents...>{},
_CUDA_VSTD::integer_sequence<size_t, _OtherExtents...>{}))::value))
__MDSPAN_INLINE_FUNCTION
__MDSPAN_CONDITIONAL_EXPLICIT(
(((_Extents != dynamic_extent) && (_OtherExtents == dynamic_extent)) || ...)
Expand Down Expand Up @@ -287,23 +284,23 @@ class extents
}

# ifdef __NVCC__
__MDSPAN_TEMPLATE_REQUIRES(
class... _Integral,
/* requires */ (
// TODO: check whether the other version works with newest NVCC, doesn't with 11.4
// NVCC seems to pick up rank_dynamic from the wrong extents type???
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
&& __MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */) &&
// NVCC chokes on the fold thingy here so wrote the workaround
((sizeof...(_Integral) == __detail::__count_dynamic_extents<_Extents...>::val)
|| (sizeof...(_Integral) == sizeof...(_Extents)))))
_LIBCUDACXX_TEMPLATE(class... _Integral)
_LIBCUDACXX_REQUIRES(
// TODO: check whether the other version works with newest NVCC, doesn't with 11.4
// NVCC seems to pick up rank_dynamic from the wrong extents type???
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
_LIBCUDACXX_AND __MDSPAN_FOLD_AND(
_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */) _LIBCUDACXX_AND
// NVCC chokes on the fold thingy here so wrote the workaround
((sizeof...(_Integral) == __detail::__count_dynamic_extents<_Extents...>::val)
|| (sizeof...(_Integral) == sizeof...(_Extents))))
# else
__MDSPAN_TEMPLATE_REQUIRES(
class... _Integral,
/* requires */ (
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
&& __MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */)
&& ((sizeof...(_Integral) == rank_dynamic()) || (sizeof...(_Integral) == rank()))))
_LIBCUDACXX_TEMPLATE(class... _Integral)
_LIBCUDACXX_REQUIRES(
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
_LIBCUDACXX_AND __MDSPAN_FOLD_AND(
_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */)
_LIBCUDACXX_AND((sizeof...(_Integral) == rank_dynamic()) || (sizeof...(_Integral) == rank())))
# endif
__MDSPAN_INLINE_FUNCTION
explicit constexpr extents(_Integral... __exts) noexcept
Expand Down Expand Up @@ -337,21 +334,16 @@ class extents
# ifdef __NVCC__
// NVCC seems to pick up rank_dynamic from the wrong extents type???
// NVCC chokes on the fold thingy here so wrote the workaround
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& ((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents)))))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(
_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents))))
# else
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& (_Np == rank() || _Np == rank_dynamic())))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND(_Np == rank() || _Np == rank_dynamic()))
# endif
__MDSPAN_CONDITIONAL_EXPLICIT(_Np != rank_dynamic())
__MDSPAN_INLINE_FUNCTION
Expand Down Expand Up @@ -386,21 +378,16 @@ class extents
# ifdef __NVCC__
// NVCC seems to pick up rank_dynamic from the wrong extents type???
// NVCC chokes on the fold thingy here so wrote the workaround
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& ((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents)))))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(
_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents))))
# else
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& (_Np == rank() || _Np == rank_dynamic())))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND(_Np == rank() || _Np == rank_dynamic()))
# endif
__MDSPAN_CONDITIONAL_EXPLICIT(_Np != rank_dynamic())
__MDSPAN_INLINE_FUNCTION
Expand Down
27 changes: 13 additions & 14 deletions libcudacxx/include/cuda/std/__mdspan/layout_left.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class layout_left::mapping
: __extents(__exts)
{}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -135,9 +135,9 @@ class layout_left::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
&& (extents_type::rank() <= 1)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
_LIBCUDACXX_AND(extents_type::rank() <= 1))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -150,8 +150,8 @@ class layout_left::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((extents_type::rank() > 0))
__MDSPAN_INLINE_FUNCTION constexpr mapping(
layout_stride::mapping<_OtherExtents> const& __other) // NOLINT(google-explicit-constructor)
Expand Down Expand Up @@ -190,11 +190,10 @@ class layout_left::mapping

//--------------------------------------------------------------------------------

__MDSPAN_TEMPLATE_REQUIRES(
class... _Indices,
/* requires */ ((sizeof...(_Indices) == extents_type::rank())
&& __MDSPAN_FOLD_AND((_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices)))))
_LIBCUDACXX_TEMPLATE(class... _Indices)
_LIBCUDACXX_REQUIRES((sizeof...(_Indices) == extents_type::rank()) _LIBCUDACXX_AND __MDSPAN_FOLD_AND(
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices))))
_CCCL_HOST_DEVICE constexpr index_type operator()(_Indices... __idxs) const noexcept
{
// Immediately cast incoming indices to `index_type`
Expand Down Expand Up @@ -227,8 +226,8 @@ class layout_left::mapping
return true;
}

__MDSPAN_TEMPLATE_REQUIRES(class _Ext = _Extents,
/* requires */ (_Ext::rank() > 0))
_LIBCUDACXX_TEMPLATE(class _Ext = _Extents)
_LIBCUDACXX_REQUIRES((_Ext::rank() > 0))
__MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type __i) const noexcept
{
Expand Down
27 changes: 13 additions & 14 deletions libcudacxx/include/cuda/std/__mdspan/layout_right.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ class layout_right::mapping
: __extents(__exts)
{}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -140,9 +140,9 @@ class layout_right::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
&& (extents_type::rank() <= 1)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
_LIBCUDACXX_AND(extents_type::rank() <= 1))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -155,8 +155,8 @@ class layout_right::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((extents_type::rank() > 0))
__MDSPAN_INLINE_FUNCTION constexpr mapping(
layout_stride::mapping<_OtherExtents> const& __other) // NOLINT(google-explicit-constructor)
Expand Down Expand Up @@ -195,11 +195,10 @@ class layout_right::mapping

//--------------------------------------------------------------------------------

__MDSPAN_TEMPLATE_REQUIRES(
class... _Indices,
/* requires */ ((sizeof...(_Indices) == extents_type::rank())
&& __MDSPAN_FOLD_AND((_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices)))))
_LIBCUDACXX_TEMPLATE(class... _Indices)
_LIBCUDACXX_REQUIRES((sizeof...(_Indices) == extents_type::rank()) _LIBCUDACXX_AND __MDSPAN_FOLD_AND(
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices))))
_CCCL_HOST_DEVICE constexpr index_type operator()(_Indices... __idxs) const noexcept
{
return __compute_offset(__rank_count<0, extents_type::rank()>(), static_cast<index_type>(__idxs)...);
Expand Down Expand Up @@ -230,8 +229,8 @@ class layout_right::mapping
return true;
}

__MDSPAN_TEMPLATE_REQUIRES(class _Ext = _Extents,
/* requires */ (_Ext::rank() > 0))
_LIBCUDACXX_TEMPLATE(class _Ext = _Extents)
_LIBCUDACXX_REQUIRES((_Ext::rank() > 0))
__MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type __i) const noexcept
{
Expand Down
Loading
Loading