Skip to content

Commit

Permalink
Enable __builtin_bit_cast for CUDA (#3066)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephanTLavavej authored Sep 1, 2022
1 parent d503c92 commit af8adfa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
4 changes: 2 additions & 2 deletions stl/inc/complex
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ namespace _Float_multi_prec {
// multiplication

// round to 26 significant bits, ties toward zero
_NODISCARD _CONSTEXPR_BIT_CAST double _High_half(const double _Val) noexcept {
_NODISCARD constexpr double _High_half(const double _Val) noexcept {
const auto _Bits = _Bit_cast<unsigned long long>(_Val);
const auto _High_half_bits = (_Bits + 0x3ff'ffffULL) & 0xffff'ffff'f800'0000ULL;
return _Bit_cast<double>(_High_half_bits);
Expand All @@ -144,7 +144,7 @@ namespace _Float_multi_prec {
// 1) _Prod0 is _Xval^2 faithfully rounded
// 2) no internal overflow or underflow occurs
// violation of condition 1 could lead to relative error on the order of epsilon
_NODISCARD _CONSTEXPR_BIT_CAST double _Sqr_error_fallback(const double _Xval, const double _Prod0) noexcept {
_NODISCARD constexpr double _Sqr_error_fallback(const double _Xval, const double _Prod0) noexcept {
const double _Xhigh = _High_half(_Xval);
const double _Xlow = _Xval - _Xhigh;
return ((_Xhigh * _Xhigh - _Prod0) + 2.0 * _Xhigh * _Xlow) + _Xlow * _Xlow;
Expand Down
30 changes: 9 additions & 21 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ _STL_DISABLE_CLANG_WARNINGS
#endif // _USE_STD_VECTOR_ALGORITHMS
#endif // ^^^ no support for vector algorithms ^^^

#ifdef __CUDACC__
#define _CONSTEXPR_BIT_CAST inline
#else // ^^^ workaround ^^^ / vvv no workaround vvv
#define _CONSTEXPR_BIT_CAST constexpr
#endif // ^^^ no workaround ^^^

#if _USE_STD_VECTOR_ALGORITHMS
_EXTERN_C
// The "noalias" attribute tells the compiler optimizer that pointers going into these hand-vectorized algorithms
Expand Down Expand Up @@ -152,14 +146,8 @@ template <class _To, class _From,
enable_if_t<conjunction_v<bool_constant<sizeof(_To) == sizeof(_From)>, is_trivially_copyable<_To>,
is_trivially_copyable<_From>>,
int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST _To _Bit_cast(const _From& _Val) noexcept {
#ifdef __CUDACC__
_To _To_obj; // assumes default-init
_CSTD memcpy(_STD addressof(_To_obj), _STD addressof(_Val), sizeof(_To));
return _To_obj;
#else // ^^^ workaround ^^^ / vvv no workaround vvv
_NODISCARD constexpr _To _Bit_cast(const _From& _Val) noexcept {
return __builtin_bit_cast(_To, _Val);
#endif // ^^^ no workaround ^^^
}

template <class _Ty>
Expand Down Expand Up @@ -5114,7 +5102,7 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(_InIt _First, const _InIt _Last, c
#else // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^ / vvv not _USE_STD_VECTOR_ALGORITHMS vvv
if constexpr (sizeof(_Iter_value_t<_InIt>) == 1) {
const auto _First_ptr = _To_address(_First);
const auto _Result = static_cast<remove_reference_t<_Iter_ref_t<_InIt>>*>(
const auto _Result = static_cast<remove_reference_t<_Iter_ref_t<_InIt>>*>(
_CSTD memchr(_First_ptr, static_cast<unsigned char>(_Val), static_cast<size_t>(_Last - _First)));
if constexpr (is_pointer_v<_InIt>) {
return _Result ? _Result : _Last;
Expand Down Expand Up @@ -6100,28 +6088,28 @@ struct _CXX17_DEPRECATE_ITERATOR_BASE_CLASS iterator { // base type for iterator
};

template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST auto _Float_abs_bits(const _Ty& _Xx) {
_NODISCARD constexpr auto _Float_abs_bits(const _Ty& _Xx) {
using _Traits = _Floating_type_traits<_Ty>;
using _Uint_type = typename _Traits::_Uint_type;
const auto _Bits = _Bit_cast<_Uint_type>(_Xx);
return _Bits & ~_Traits::_Shifted_sign_mask;
}

template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST _Ty _Float_abs(const _Ty _Xx) { // constexpr floating-point abs()
_NODISCARD constexpr _Ty _Float_abs(const _Ty _Xx) { // constexpr floating-point abs()
return _Bit_cast<_Ty>(_Float_abs_bits(_Xx));
}

template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST _Ty _Float_copysign(const _Ty _Magnitude, const _Ty _Sign) { // constexpr copysign()
_NODISCARD constexpr _Ty _Float_copysign(const _Ty _Magnitude, const _Ty _Sign) { // constexpr copysign()
using _Traits = _Floating_type_traits<_Ty>;
using _Uint_type = typename _Traits::_Uint_type;
const auto _Signbit = _Bit_cast<_Uint_type>(_Sign) & _Traits::_Shifted_sign_mask;
return _Bit_cast<_Ty>(_Float_abs_bits(_Magnitude) | _Signbit);
}

template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST bool _Is_nan(const _Ty _Xx) { // constexpr isnan()
_NODISCARD constexpr bool _Is_nan(const _Ty _Xx) { // constexpr isnan()
using _Traits = _Floating_type_traits<_Ty>;
return _Float_abs_bits(_Xx) > _Traits::_Shifted_exponent_mask;
}
Expand All @@ -6131,20 +6119,20 @@ _NODISCARD _CONSTEXPR_BIT_CAST bool _Is_nan(const _Ty _Xx) { // constexpr isnan(
// When the value is a 32-bit or 64-bit signaling NaN, the conversion to/from 80-bit raises FE_INVALID
// and turns it into a quiet NaN. This behavior is undesirable if we want to test for signaling NaNs.
template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST bool _Is_signaling_nan(const _Ty& _Xx) { // returns true if input is a signaling NaN
_NODISCARD constexpr bool _Is_signaling_nan(const _Ty& _Xx) { // returns true if input is a signaling NaN
using _Traits = _Floating_type_traits<_Ty>;
const auto _Abs_bits = _Float_abs_bits(_Xx);
return _Abs_bits > _Traits::_Shifted_exponent_mask && ((_Abs_bits & _Traits::_Special_nan_mantissa_mask) == 0);
}

template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST bool _Is_inf(const _Ty _Xx) { // constexpr isinf()
_NODISCARD constexpr bool _Is_inf(const _Ty _Xx) { // constexpr isinf()
using _Traits = _Floating_type_traits<_Ty>;
return _Float_abs_bits(_Xx) == _Traits::_Shifted_exponent_mask;
}

template <class _Ty, enable_if_t<is_floating_point_v<_Ty>, int> = 0>
_NODISCARD _CONSTEXPR_BIT_CAST bool _Is_finite(const _Ty _Xx) { // constexpr isfinite()
_NODISCARD constexpr bool _Is_finite(const _Ty _Xx) { // constexpr isfinite()
using _Traits = _Floating_type_traits<_Ty>;
return _Float_abs_bits(_Xx) < _Traits::_Shifted_exponent_mask;
}
Expand Down

0 comments on commit af8adfa

Please sign in to comment.