cuda::std::complex specializations for half and bfloat #1140
Conversation
That is a great job working around the quirks of those types 👏

I would love to move some of the traits around (e.g. into is_floating_point.h) and, importantly, add a proper named define that one can grep for.
LGTM in general, thanks for working on this @griwes!

I think I may have missed the static_asserts for the size and alignment of complex half and bfloat; do we have these somewhere? Thanks!
I am wondering whether we should just keep all the _LIBCUDACXX_HAS_NO_NVFP16 checks in place and define it conditionally for host.
Specifically:
* disable BF16 when FP16 is disabled, since the former includes the latter;
* disable both when the toolkit version is lower than 12.2, since 12.2 is when both types got the host versions of a lot of the functions we need to make useful heterogeneous things with them;
* disable both in host-only TUs, as there's no easy way I could find to detect the condition above.

I've included an opt-in macro for asserting that the headers (if available) are from a sufficiently new CTK; I will add that to the docs in a later commit.
NVCC is emitting code that makes various versions of Clang unhappy about a deprecated implicit copy constructor of a lambda wrapper, so just work around that by not using one.
Co-authored-by: Wesley Maxey <[email protected]>
Note: As discussed offline, local tests show that at least on sm86/89 we need this patch for performance reasons. I haven't had a chance to test on sm70/80/90, though.

```diff
diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/complex b/libcudacxx/include/cuda/std/detail/libcxx/include/complex
index 3ba249779..416c0e71d 100644
--- a/libcudacxx/include/cuda/std/detail/libcxx/include/complex
+++ b/libcudacxx/include/cuda/std/detail/libcxx/include/complex
@@ -1702,6 +1702,16 @@ atanh(const complex<_Tp>& __x)
     return complex<_Tp>(__constexpr_copysign(__z.real(), __x.real()), __constexpr_copysign(__z.imag(), __x.imag()));
 }
 
+// we add a specialization for fp16 atanh because of performance issues
+template<>
+_LIBCUDACXX_INLINE_VISIBILITY complex<__half>
+atanh(const complex<__half>& __x)
+{
+    complex<float> __temp(__x);
+    __temp = _CUDA_VSTD::atanh(__temp);
+    return complex<__half>(__temp.real(), __temp.imag());
+}
+
 // sinh
 
 template<class _Tp>
@@ -1815,6 +1825,16 @@ atan(const complex<_Tp>& __x)
     return complex<_Tp>(__z.imag(), -__z.real());
 }
 
+// we add a specialization for fp16 atan because of performance issues
+template<>
+_LIBCUDACXX_INLINE_VISIBILITY complex<__half>
+atan(const complex<__half>& __x)
+{
+    complex<float> __temp(__x);
+    __temp = _CUDA_VSTD::atan(__temp);
+    return complex<__half>(__temp.real(), __temp.imag());
+}
+
 // sin
 
 template<class _Tp>
```
@leofang I added some workarounds for
Description

Resolves #1139

Introduce specializations of complex<T> for half and bfloat.